From 345a3dcdda91e7508314024e465d0feb7e1c789d Mon Sep 17 00:00:00 2001 From: Will Guo Date: Mon, 15 Dec 2025 08:38:02 +0000 Subject: [PATCH 1/5] Integrate Automated QDQ placement tool - part 2 Signed-off-by: Will Guo --- modelopt/onnx/op_types.py | 19 + .../onnx/quantization/autotune/__init__.py | 144 + modelopt/onnx/quantization/autotune/common.py | 1371 ++++++++++ .../quantization/autotune/region_inspect.py | 201 ++ .../quantization/autotune/region_pattern.py | 669 +++++ .../quantization/autotune/region_search.py | 2348 +++++++++++++++++ modelopt/onnx/quantization/graph_utils.py | 17 + modelopt/onnx/quantization/qdq_utils.py | 44 + .../autotune/test_pattern_cache.py | 237 ++ .../autotune/test_region_pattern.py | 410 +++ .../autotune/test_region_search.py | 420 +++ 11 files changed, 5880 insertions(+) create mode 100644 modelopt/onnx/quantization/autotune/__init__.py create mode 100644 modelopt/onnx/quantization/autotune/common.py create mode 100644 modelopt/onnx/quantization/autotune/region_inspect.py create mode 100644 modelopt/onnx/quantization/autotune/region_pattern.py create mode 100644 modelopt/onnx/quantization/autotune/region_search.py create mode 100644 tests/unit/onnx/quantization/autotune/test_pattern_cache.py create mode 100644 tests/unit/onnx/quantization/autotune/test_region_pattern.py create mode 100644 tests/unit/onnx/quantization/autotune/test_region_search.py diff --git a/modelopt/onnx/op_types.py b/modelopt/onnx/op_types.py index cc94a221f..0352e7106 100644 --- a/modelopt/onnx/op_types.py +++ b/modelopt/onnx/op_types.py @@ -303,3 +303,22 @@ def is_data_dependent_shape_op(op_type: str): "NonZero", "RoiAlign", ] + + +def get_symmetric_ops(): + """Returns set of commutative/symmetric operations where operand order doesn't matter.""" + return { + "Add", + "Mul", + "And", + "Or", + "Xor", + "Equal", + "Max", + "Min", + "Sum", + "Mean", + "BitwiseAnd", + "BitwiseOr", + "BitwiseXor", + } diff --git a/modelopt/onnx/quantization/autotune/__init__.py b/modelopt/onnx/quantization/autotune/__init__.py new file mode 100644 index 000000000..a65b2ccba --- /dev/null +++ b/modelopt/onnx/quantization/autotune/__init__.py @@ -0,0 +1,144 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Pattern-Based Q/DQ Autotuning for ONNX Models. + +This package provides automated optimization of Quantize/Dequantize (Q/DQ) node placement +in ONNX computation graphs to minimize TensorRT inference latency. It uses pattern-based +region analysis to efficiently explore and optimize Q/DQ insertion strategies. 
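+
+For example, commutative operations (listed by ``get_symmetric_ops`` in
+``modelopt.onnx.op_types``, added in this patch) are treated as
+order-independent during pattern analysis — an illustrative check:
+
+    >>> from modelopt.onnx.op_types import get_symmetric_ops
+    >>> "Add" in get_symmetric_ops()
+    True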
+ +**Key Features:** + +- **Automated Region Discovery**: Hierarchical decomposition of computation graphs into + LEAF and COMPOSITE regions with automatic pattern identification + +- **Pattern-Based Optimization**: Groups structurally-similar regions and optimizes them + together, making the process efficient and consistent + +- **TensorRT Performance Measurement**: Direct integration with TensorRT Python API for + accurate latency profiling of each Q/DQ configuration + +- **State Management**: Checkpoint/resume capability for long-running optimizations with + incremental state saving after each region + +- **Pattern Cache**: Warm-start optimization using learned schemes from previous runs, + enabling transfer learning across models + +**Core Components:** + +Autotuner Classes: + - QDQAutotuner: Main autotuner with automatic hierarchical region discovery + - QDQAutotunerBase: Base class for custom region identification strategies + +Region Management: + - Region: Hierarchical subgraph representation (nodes + children) + - RegionType: Enumeration (LEAF, COMPOSITE, ROOT) + - CombinedRegionSearch: Two-phase region discovery (partitioning + refinement) + - RegionPattern: Structural pattern analysis and matching for region grouping + +Q/DQ Insertion Points: + - InsertionScheme: Collection of Q/DQ insertion points for a region pattern + - NodeInputInsertionPoint: Q/DQ insertion at specific node inputs + - ChildRegionInputInsertionPoint: Q/DQ insertion at child region input boundaries + - RegionOutputInsertionPoint: Q/DQ insertion at region output boundaries + +Configuration & State: + - Config: Autotuning parameters (quant type, thresholds, verbosity) + - PatternCache: Top-performing schemes indexed by pattern (warm-start) + - PatternSchemes: Scheme collection and measurement results for a pattern + +Benchmarking: + - Benchmark: Abstract base class for model benchmarking + - TensorRTPyBenchmark: Benchmark using TensorRT Python API (recommended) + - TrtExecBenchmark: Benchmark using trtexec command-line tool (legacy) + +**Quick Start:** + + >>> from modelopt.onnx.quantization.autotune import QDQAutotuner, Config + >>> import onnx + >>> # Load model and initialize autotuner + >>> model = onnx.load("model.onnx") + >>> autotuner = QDQAutotuner(model) + >>> # Configure autotuning parameters + >>> config = Config(default_quant_type="int8") + >>> autotuner.initialize(config) + >>> # Generate and test Q/DQ schemes + >>> # (see workflows.region_pattern_autotuning_workflow for complete example) + +**Command-Line Interface:** + + The package can be run directly as a module: + + $ python -m modelopt.onnx.quantization.autotune --model model.onnx --output ./output + $ python -m modelopt.onnx.quantization.autotune --model model.onnx --quant-type fp8 + +**See Also:** + + - workflows.region_pattern_autotuning_workflow: Complete end-to-end optimization + - QDQAutotuner: Main autotuner class documentation + - RegionPattern: Pattern matching and signature computation +""" + +# Core data structures +from .common import ( + AutotunerError, + AutotunerNotInitializedError, + Config, + InsertionScheme, + InvalidSchemeError, + PatternCache, + PatternSchemes, + Region, + RegionError, + RegionType, +) + +# Insertion points (from dedicated module) +from .insertion_points import ( + ChildRegionInputInsertionPoint, + NodeInputInsertionPoint, + RegionOutputInsertionPoint, + ResolvedInsertionPoint, +) + +# Pattern analysis +from .region_pattern import RegionPattern + +# Region search +from .region_search import 
CombinedRegionSearch + +# Public API +__all__ = [ + # Exceptions + "AutotunerError", + "AutotunerNotInitializedError", + "ChildRegionInputInsertionPoint", + "CombinedRegionSearch", + # Configuration and state + "Config", + # Q/DQ insertion + "InsertionScheme", + "InvalidSchemeError", + "NodeInputInsertionPoint", + "ResolvedInsertionPoint", + "PatternCache", + "PatternSchemes", + # Region classes + "Region", + "RegionError", + "RegionOutputInsertionPoint", + "RegionPattern", + "RegionType", +] diff --git a/modelopt/onnx/quantization/autotune/common.py b/modelopt/onnx/quantization/autotune/common.py new file mode 100644 index 000000000..42b63c251 --- /dev/null +++ b/modelopt/onnx/quantization/autotune/common.py @@ -0,0 +1,1371 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Common data structures and types for the QDQ Autotuner. + +This module provides the foundational classes used throughout the autotuner: + +**Exceptions:** +- Region-related: RegionError +- Autotuner-related: AutotunerError, AutotunerNotInitializedError, InvalidSchemeError + +**Region Hierarchy:** +- Region: Hierarchical subgraph representation with parent/child relationships +- RegionType: Enumeration for LEAF, COMPOSITE, and ROOT regions + +**Q/DQ Insertion Specifications:** +- InsertionScheme: Collection of insertion points with performance metrics + +**Scheme Management:** +- PatternSchemes: Multiple insertion schemes for a pattern (applies to all matching regions) +- PatternCache: Collection of top schemes for multiple patterns, used as autotuning seeds + +**Configuration:** +- Config: Autotuning parameters and Q/DQ default values +""" + +import hashlib +import logging +from dataclasses import dataclass, field +from enum import Enum +from typing import TYPE_CHECKING, Any, Optional + +import onnx_graphsurgeon as gs +import yaml + +from modelopt.onnx.quantization.autotune.insertion_points import ( + ChildRegionInputInsertionPoint, + NodeInputInsertionPoint, + RegionOutputInsertionPoint, + ResolvedInsertionPoint, +) + +if TYPE_CHECKING: + from modelopt.onnx.quantization.autotune.region_pattern import RegionPattern + +# Module logger +logger = logging.getLogger(__name__) + + +# Region-related Exceptions +class RegionError(Exception): + """Base exception for region-related errors.""" + + +# Autotuner-related Exceptions +class AutotunerError(Exception): + """Base exception for autotuner-related errors.""" + + +class AutotunerNotInitializedError(AutotunerError): + """Exception raised when autotuner is used without initialization.""" + + +class InvalidSchemeError(AutotunerError): + """Exception raised when an invalid scheme is referenced.""" + + +class RegionType(Enum): + """Region type enumeration for hierarchical graph structure. 
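+    Members (listed below) can be constructed from their string values
+    (illustrative):
+
+        >>> RegionType("LEAF") is RegionType.LEAF
+        True
+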
+ + - LEAF: Atomic region containing direct nodes with no child regions + - COMPOSITE: Hierarchical region containing child regions (and optionally direct nodes) + - ROOT: Top-level region encompassing the entire computation graph + """ + + LEAF = "LEAF" + COMPOSITE = "COMPOSITE" + ROOT = "ROOT" + + +class Region: + """Hierarchical subgraph region in an ONNX computation graph. + + A Region represents a cohesive subgraph with well-defined boundaries, supporting: + + **Hierarchical Structure:** + - Parent/child relationships forming a multi-level hierarchy + - LEAF regions contain only direct nodes + - COMPOSITE regions contain child regions (and optionally direct nodes) + - ROOT regions encompass the entire graph + + **Node Management:** + - Direct nodes: Operations directly in this region (not in children) + - Recursive nodes: All operations including those in descendant regions + + **Boundary Tracking:** + - Input tensors: Data entering the region from outside + - Output tensors: Data leaving the region to outside consumers + + **Pattern Matching:** + - Regions with identical structure share the same pattern signature + - Pattern-based optimization applies schemes to all matching regions + + Regions are the fundamental unit for Q/DQ insertion and optimization. + """ + + def __init__(self, region_id: int, level: int, region_type: RegionType): + """Initialize a new region. + + Args: + region_id: Unique identifier within the region hierarchy + level: Hierarchical level (0 = leaf, higher = more composite) + region_type: Type classification (LEAF, COMPOSITE, or ROOT) + """ + self.id = region_id + self.level = level + self.type = region_type + self.parent: Region | None = None + self.children: list[Region] = [] + self.nodes: set[int] = set() + self.inputs: list[str] = [] + self.outputs: list[str] = [] + + # ========================================================================= + # Basic Accessors + # ========================================================================= + + def get_id(self) -> int: + """Get region ID.""" + return self.id + + def set_id(self, region_id: int) -> None: + """Set region ID (for RegionBuilder use).""" + self.id = region_id + + def get_level(self) -> int: + """Get region level in hierarchy.""" + return self.level + + def set_level(self, level: int) -> None: + """Set region level in hierarchy (for RegionBuilder use).""" + self.level = level + + def get_type(self) -> RegionType: + """Get region type.""" + return self.type + + def set_type(self, region_type: RegionType) -> None: + """Set region type (for RegionBuilder use).""" + self.type = region_type + + # ========================================================================= + # Hierarchy Management + # ========================================================================= + + def get_parent(self) -> Optional["Region"]: + """Get parent region.""" + return self.parent + + def set_parent(self, parent: Optional["Region"]) -> None: + """Set parent region.""" + self.parent = parent + + def get_children(self) -> list["Region"]: + """Get all child regions.""" + return self.children + + def remove_child(self, child: "Region") -> bool: + """Remove a child region from this region's children list. 
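+
+        A minimal sketch of the expected behavior (illustrative):
+
+            >>> parent = Region(0, 1, RegionType.COMPOSITE)
+            >>> child = Region(1, 0, RegionType.LEAF)
+            >>> parent.add_child(child)
+            >>> parent.remove_child(child)
+            True
+            >>> child.get_parent() is None
+            True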
+ + Args: + child: The child region to remove + + Returns: + True if child was found and removed, False otherwise + """ + child_id = child.get_id() + initial_count = len(self.children) + self.children = [c for c in self.children if c.get_id() != child_id] + removed = len(self.children) < initial_count + + if removed and child.parent and child.parent.get_id() == self.id: + child.set_parent(None) + + return removed + + def add_child(self, child: "Region") -> None: + """Add a child sub-region.""" + # Prevent adding self as child + if child.get_id() == self.id: + logger.warning(f"Cannot add region {self.id} as its own child") + return + + # Prevent creating cycles: check if self is already a descendant of child + if self._is_descendant_of(child): + logger.warning( + f"Cycle detected: region {self.id} is already a descendant of region {child.get_id()}" + ) + return + + # Check if child already has a different parent + if child.parent is not None and child.parent.get_id() != self.id: + old_parent_id = child.parent.get_id() + logger.debug( + f"Re-parenting region {child.get_id()}: moving from parent {old_parent_id} to {self.id}" + ) + # Remove from old parent to maintain tree structure + child.parent.remove_child(child) + + # Check if child is already in children list + if any(c.get_id() == child.get_id() for c in self.children): + logger.debug(f"Region {child.get_id()} already child of {self.id}") + return + + self.children.append(child) + child.set_parent(self) + + def _is_descendant_of(self, potential_ancestor: "Region") -> bool: + """Check if this region is a descendant of potential_ancestor.""" + visited = set() + current = self.parent + while current: + if current.get_id() in visited: + # Already visited, there's a cycle in parents + return False + visited.add(current.get_id()) + if current.get_id() == potential_ancestor.get_id(): + return True + current = current.parent + return False + + # ========================================================================= + # Node Management + # ========================================================================= + + def add_node(self, node_index: int) -> None: + """Add a node index to this region.""" + self.nodes.add(node_index) + + def add_nodes(self, node_indices: list[int]) -> None: + """Add multiple node indices to this region.""" + self.nodes.update(node_indices) + + def get_nodes(self) -> set[int]: + """Get direct node indices in this region only. + + Returns only nodes directly owned by this region, excluding nodes + in child regions. Use get_all_nodes_recursive() for complete coverage. + + Returns: + Set of node indices (absolute positions in the graph) + """ + return self.nodes + + def get_all_nodes_recursive(self, _visited: set[int] | None = None) -> set[int]: + """Get all node indices recursively, including descendants. + + Traverses the entire subtree rooted at this region, collecting nodes + from this region and all child regions recursively. 
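+
+        For example (illustrative):
+
+            >>> parent = Region(0, 1, RegionType.COMPOSITE)
+            >>> child = Region(1, 0, RegionType.LEAF)
+            >>> parent.add_node(1)
+            >>> child.add_node(2)
+            >>> parent.add_child(child)
+            >>> sorted(parent.get_all_nodes_recursive())
+            [1, 2]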
+ + Args: + _visited: Internal parameter for cycle detection (do not use) + + Returns: + Set of all node indices in this region and its descendants + """ + if _visited is None: + _visited = set() + + # Detect cycles + if self.id in _visited: + logger.warning(f"Cycle detected in region {self.id} during node traversal") + return set() + + _visited.add(self.id) + all_nodes = set(self.nodes) + for child in self.children: + all_nodes.update(child.get_all_nodes_recursive(_visited)) + return all_nodes + + def contains_node(self, node_index: int) -> bool: + """Check if region contains a specific node (direct only).""" + return node_index in self.nodes + + def contains_node_recursive(self, node_index: int, _visited: set[int] | None = None) -> bool: + """Check if region contains a node recursively.""" + if _visited is None: + _visited = set() + + # Detect cycles + if self.id in _visited: + return False + + _visited.add(self.id) + + if self.contains_node(node_index): + return True + return any(child.contains_node_recursive(node_index, _visited) for child in self.children) + + # ========================================================================= + # Input/Output Management + # ========================================================================= + + def add_input(self, tensor_name: str) -> None: + """Add an input tensor name.""" + if tensor_name not in self.inputs: + self.inputs.append(tensor_name) + + def add_output(self, tensor_name: str) -> None: + """Add an output tensor name.""" + if tensor_name not in self.outputs: + self.outputs.append(tensor_name) + + def get_inputs(self) -> list[str]: + """Get region input tensors.""" + return self.inputs + + def get_outputs(self) -> list[str]: + """Get region output tensors.""" + return self.outputs + + # ========================================================================= + # Size and Query Methods + # ========================================================================= + + def get_size(self) -> int: + """Get the number of direct nodes in this region. + + Returns: + Count of nodes directly in this region (excludes child regions) + """ + return len(self.nodes) + + def get_total_size(self, _visited: set[int] | None = None) -> int: + """Get total node count recursively including all descendants. + + Computes the sum of nodes in this region and all child regions, + providing the total footprint of the region subtree. + + Args: + _visited: Internal parameter for cycle detection (do not use) + + Returns: + Total number of nodes in this region and all descendants + """ + if _visited is None: + _visited = set() + + # Detect cycles + if self.id in _visited: + logger.warning(f"Cycle detected in region {self.id} during size calculation") + return len(self.nodes) + + _visited.add(self.id) + total = len(self.nodes) + for child in self.children: + total += child.get_total_size(_visited) + return total + + # ========================================================================= + # Region Operations + # ========================================================================= + + def merge(self, other: "Region") -> None: + """Merge another region into this one. + + Combines the nodes and children from the other region into this region. + The other region's children become children of this region, updating + their parent references accordingly. 
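+
+        For example (illustrative):
+
+            >>> a = Region(0, 1, RegionType.COMPOSITE)
+            >>> b = Region(1, 1, RegionType.COMPOSITE)
+            >>> b.add_node(7)
+            >>> a.merge(b)
+            >>> a.contains_node(7)
+            True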
+ + Args: + other: Region to merge into this one + """ + if not other: + return + # Merge direct nodes + self.nodes.update(other.nodes) + # Merge children (updates their parent references) + for child in other.children: + self.add_child(child) + + # ========================================================================= + # String Representation + # ========================================================================= + + def to_string(self) -> str: + """Print region information for debugging.""" + type_str = self.type.value + return ( + f"Region[id={self.id}, level={self.level}, type={type_str}, " + f"nodes={len(self.nodes)}, children={len(self.children)}, " + f"inputs={len(self.inputs)}, outputs={len(self.outputs)}]" + ) + + def __str__(self) -> str: + return self.to_string() + + def __repr__(self) -> str: + return self.to_string() + + def compute_structural_signature(self, graph: gs.Graph) -> str: + """Compute deterministic structural signature for pattern matching. + + Creates a signature that uniquely identifies the region's topology, + node operations, and hierarchical structure. Regions with identical + signatures can share Q/DQ insertion schemes. + + The signature captures: + - Node operation types and key parameters + - Hierarchical structure (child regions) + - Deterministic ordering (sorted for consistency) + + Args: + graph: The ONNX graph containing the region's nodes + + Returns: + Signature string (e.g., "Conv->BatchNorm->Relu" or "COMPOSITE(...)") + """ + # Import here to avoid circular dependency at runtime + from modelopt.onnx.quantization.autotune.region_pattern import RegionPattern + + return RegionPattern.from_region(self, graph).signature + + +# ============================================================================= +# Autotuner Q/DQ Insertion Specifications +# ============================================================================= + + +@dataclass +class InsertionScheme: + """Complete Q/DQ insertion specification for a region pattern. + + An InsertionScheme defines a complete Q/DQ configuration for a pattern, + combining both node-level and region-level insertion points. The scheme + is applied to all regions matching the pattern. 
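+
+    A quick sketch (illustrative):
+
+        >>> scheme = InsertionScheme(node_inputs=[NodeInputInsertionPoint(0, 0)])
+        >>> scheme.is_empty
+        False
+        >>> scheme.is_profiled  # not yet measured
+        False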
+ + **Scheme Identity:** + - Uniquely identified by the combination of insertion points (computed hash) + - latency_ms is a measured performance metric, not part of identity + - Two schemes with same insertion points but different latencies are considered identical + + **Application:** + - Node insertion points: Q/DQ at node inputs within the pattern + - Region insertion points: Q/DQ at child region boundaries (COMPOSITE only) + - All are resolved to actual configurations for each matching region + + **Performance Tracking:** + - latency_ms: Measured performance (inf = not yet measured) + - error: Whether this scheme encountered an error during measurement + - Used to select the best scheme for each pattern + + **Attributes:** + node_inputs: Q/DQ insertions at node inputs (list of NodeInputInsertionPoint) + child_region_inputs: Q/DQ insertions at child boundaries (list of ChildRegionInputInsertionPoint) + region_outputs: Q/DQ insertions at region outputs (list of RegionOutputInsertionPoint) + latency_ms: Measured latency in milliseconds (inf if not measured) + error: True if scheme measurement failed, False otherwise + profile_timestamp: ISO format timestamp when this scheme was profiled (None if not yet profiled) + """ + + node_inputs: list[NodeInputInsertionPoint] = field(default_factory=list) + child_region_inputs: list[ChildRegionInputInsertionPoint] = field(default_factory=list) + region_outputs: list[RegionOutputInsertionPoint] = field(default_factory=list) + latency_ms: float = float("inf") + error: bool = False + profile_timestamp: str | None = None + + @property + def hash(self) -> str: + """Compute deterministic hash for scheme identity. + + The hash uniquely identifies this scheme configuration based on its + insertion points. Two schemes with identical insertion points produce + the same hash, regardless of their measured latencies. + + **Hash Input:** + - Sorted node_inputs (for deterministic ordering) + - Sorted child_region_inputs (for deterministic ordering) + - Sorted region_outputs (for deterministic ordering) + - latency_ms is EXCLUDED (performance metric, not identity) + + **Use Cases:** + - Detect duplicate schemes before measurement + - Group schemes by configuration + - Efficient scheme comparison + + Returns: + 32-character hexadecimal string (SHA-256 truncated to 128 bits) + """ + # Sort points for deterministic hashing + sorted_nodes = sorted([(pt.node_index, pt.input_index) for pt in self.node_inputs]) + sorted_regions = sorted( + [(pt.region_index, pt.input_index) for pt in self.child_region_inputs] + ) + sorted_region_outputs = sorted( + [(pt.region_index, pt.node_index, pt.output_index) for pt in self.region_outputs] + ) + + # Create hash input string + hash_input = f"{sorted_nodes}|{sorted_regions}|{sorted_region_outputs}" + + # Compute SHA-256 hash (128 bits) + return hashlib.sha256(hash_input.encode("utf-8")).hexdigest()[:32] + + @property + def is_empty(self) -> bool: + """Check if this is a baseline scheme with no Q/DQ insertions. + + Returns: + True if scheme has no node/region insertion points + """ + return ( + len(self.node_inputs) == 0 + and len(self.child_region_inputs) == 0 + and len(self.region_outputs) == 0 + ) + + @property + def has_error(self) -> bool: + """Check if this scheme encountered an error during measurement. + + Returns: + True if scheme has error=True, False otherwise + """ + return self.error + + @property + def is_profiled(self) -> bool: + """Check if this scheme has been profiled (measured). 
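+
+        For example, a freshly constructed scheme is unprofiled (illustrative):
+
+            >>> InsertionScheme().is_profiled
+            False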
+ + A scheme is considered profiled if it has been measured (has non-infinite latency) + or has encountered an error during measurement. + + Returns: + True if scheme has been measured (latency_ms != inf) or has error, + False if scheme is waiting to be profiled (error=False and latency_ms=inf) + """ + return self.error or self.latency_ms != float("inf") + + @property + def num_node_insertions(self) -> int: + """Get count of node-level Q/DQ insertion points. + + Returns: + Number of NodeInputInsertionPoint entries + """ + return len(self.node_inputs) + + @property + def num_region_insertions(self) -> int: + """Get count of region-level Q/DQ insertion points. + + These specify Q/DQ insertions at child region boundaries within + COMPOSITE regions. + + Returns: + Number of ChildRegionInputInsertionPoint entries + """ + return len(self.child_region_inputs) + + @property + def num_region_output_insertions(self) -> int: + """Get count of region output insertion points. + + These specify Q/DQ insertions at outputs from child regions or nodes. + + Returns: + Number of RegionOutputInsertionPoint entries + """ + return len(self.region_outputs) + + def to_dict(self) -> dict[str, Any]: + """Convert to dictionary for serialization.""" + return { + "latency_ms": self.latency_ms, + "error": self.error, + "profile_timestamp": self.profile_timestamp, + "nodes_insertion_points": [pt.to_dict() for pt in self.node_inputs], + "child_region_inputs": [pt.to_dict() for pt in self.child_region_inputs], + "region_outputs": [pt.to_dict() for pt in self.region_outputs], + "hash": self.hash, + } + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "InsertionScheme": + """Create InsertionScheme from serialized dictionary. + + Reconstructs the insertion scheme from saved data, including node and + region insertion points. The hash is automatically recomputed from all + components to ensure consistency. + + Args: + data: Dictionary containing 'latency_ms', 'nodes_insertion_points', + 'child_region_inputs', and 'region_outputs' keys + + Returns: + Reconstructed InsertionScheme instance + """ + scheme = cls() + scheme.latency_ms = data.get("latency_ms", float("inf")) + scheme.error = data.get("error", False) + scheme.profile_timestamp = data.get("profile_timestamp") + + scheme.node_inputs = [ + NodeInputInsertionPoint.from_dict(pt) for pt in data.get("nodes_insertion_points", []) + ] + scheme.child_region_inputs = [ + ChildRegionInputInsertionPoint.from_dict(pt) + for pt in data.get("child_region_inputs", []) + ] + scheme.region_outputs = [ + RegionOutputInsertionPoint.from_dict(pt) for pt in data.get("region_outputs", []) + ] + + # Note: hash is computed from points, so we don't load it from dict + # This ensures consistency even if stored hash differs + + return scheme + + def distance(self, other: "InsertionScheme") -> int: + """Compute edit distance between this scheme and another scheme. + + The edit distance is the minimum number of add/remove operations needed + to transform this scheme into the other scheme. This is computed as the + symmetric difference between the insertion point sets. 
+ + **Distance Calculation:** + - Counts insertion points in self but not in other (need to be removed) + - Counts insertion points in other but not in self (need to be added) + - Considers all three types of insertion points: + * node_inputs + * child_region_inputs + * region_outputs + + Args: + other: InsertionScheme to compare against + + Returns: + Total edit distance (number of add + remove operations) + + Example: + >>> scheme1 = InsertionScheme( + ... node_inputs=[ + ... NodeInputInsertionPoint(0, 0), + ... NodeInputInsertionPoint(1, 0), + ... ] + ... ) + >>> scheme2 = InsertionScheme( + ... node_inputs=[ + ... NodeInputInsertionPoint(0, 0), + ... NodeInputInsertionPoint(2, 0), + ... ] + ... ) + >>> scheme1.distance(scheme2) # 2 (remove (1,0), add (2,0)) + 2 + """ + # Convert insertion points to sets for efficient set operations + self_nodes = set(self.node_inputs) + other_nodes = set(other.node_inputs) + + self_regions = set(self.child_region_inputs) + other_regions = set(other.child_region_inputs) + + self_region_outputs = set(self.region_outputs) + other_region_outputs = set(other.region_outputs) + + # Compute symmetric difference (elements in either set but not both) + # This gives us the total number of add + remove operations + node_distance = len(self_nodes.symmetric_difference(other_nodes)) + region_distance = len(self_regions.symmetric_difference(other_regions)) + region_output_distance = len(self_region_outputs.symmetric_difference(other_region_outputs)) + + return node_distance + region_distance + region_output_distance + + def __str__(self) -> str: + """String representation for debugging.""" + error_str = ", error=True" if self.error else "" + return ( + f"InsertionScheme(node_insertions={self.num_node_insertions}, " + f"region_insertions={self.num_region_insertions}, " + f"region_output_insertions={self.num_region_output_insertions}, " + f"latency={self.latency_ms:.3f}ms{error_str})" + ) + + +@dataclass +class PatternSchemes: + """Collection of Q/DQ insertion schemes for a single pattern. + + Manages multiple InsertionScheme candidates for a region pattern, tracking + their performance and identifying the best-performing configuration. This + enables pattern-based optimization where all regions with the same structure + use the same Q/DQ insertion strategy. + + **Workflow:** + 1. Pattern is identified from region structure + 2. Multiple schemes are generated and tested + 3. Each scheme is measured (latency_ms) + 4. Best scheme is selected (lowest latency) + 5. 
Best scheme is applied to all matching regions + + **Best Scheme Selection:** + - Automatically identifies scheme with lowest latency + - Excludes schemes with errors (error=True) + - Schemes with latency_ms = inf are considered unmeasured + - best_scheme property provides easy access to optimal configuration + + **Attributes:** + pattern: RegionPattern defining the structural signature + schemes: List of InsertionScheme candidates with measurements + """ + + pattern: Optional["RegionPattern"] = None # Structural pattern signature + schemes: list[InsertionScheme] = field(default_factory=list) # Candidate schemes + + @property + def pattern_signature(self) -> str: + """Get the pattern signature string.""" + return self.pattern.signature if self.pattern else "" + + @property + def pattern_size(self) -> int: + """Get the pattern size (total node count).""" + return self.pattern.size if self.pattern else 0 + + @property + def best_scheme_index(self) -> int: + """Get index of the best performing scheme (lowest latency). + + Scans all schemes to find the one with minimum latency_ms, + excluding schemes with errors. + If no schemes exist or all have errors, returns -1. + + Returns: + Index of best scheme (without errors), or -1 if no valid schemes available + """ + if len(self.schemes) == 0: + return -1 + min_idx, min_latency = -1, float("inf") + for idx, scheme in enumerate(self.schemes): + if not scheme.has_error and scheme.latency_ms < min_latency: + min_idx = idx + min_latency = scheme.latency_ms + return min_idx + + @property + def best_scheme(self) -> InsertionScheme | None: + """Get the best performing scheme (lowest latency). + + Convenience property for accessing the optimal scheme directly + without needing to look up by index. Excludes schemes with errors. + + Returns: + InsertionScheme with lowest latency (excluding error schemes), + or None if no valid schemes exist + """ + index = self.best_scheme_index + if index < 0 or index >= len(self.schemes): + return None + return self.schemes[index] + + @property + def num_schemes(self) -> int: + """Get total number of schemes.""" + return len(self.schemes) + + @property + def has_schemes(self) -> bool: + """Check if any schemes have been added.""" + return len(self.schemes) > 0 + + def add_scheme(self, scheme: InsertionScheme) -> None: + """Add a scheme to the collection. + + Args: + scheme: InsertionScheme to add + """ + self.schemes.append(scheme) + + def get_measured_schemes(self) -> list[InsertionScheme]: + """Get schemes that have been measured (finite latency). + + Returns: + List of schemes with performance measurements (excludes unmeasured schemes with inf latency) + """ + return [s for s in self.schemes if s.latency_ms != float("inf")] + + def get_valid_schemes(self) -> list[InsertionScheme]: + """Get schemes without errors. + + Returns: + List of schemes that completed successfully without errors + """ + return [s for s in self.schemes if not s.has_error] + + def to_dict(self) -> dict[str, Any]: + """Convert to dictionary for serialization. + + Note: Excludes runtime objects like pattern (RegionPattern). + Only serializes metadata and schemes. + """ + return { + "pattern_signature": self.pattern_signature, + "pattern_size": self.pattern_size, + "schemes": [scheme.to_dict() for scheme in self.schemes], + } + + @classmethod + def from_dict( + cls, data: dict[str, Any], pattern: Optional["RegionPattern"] = None + ) -> "PatternSchemes": + """Create PatternSchemes from serialized dictionary. 
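+
+        A minimal round-trip sketch (illustrative; assumes ``original`` is an
+        existing PatternSchemes instance):
+
+            >>> restored = PatternSchemes.from_dict(original.to_dict())
+            >>> restored.pattern_signature == original.pattern_signature
+            True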
+ + Reconstructs the pattern schemes collection from saved data. The + RegionPattern object must be provided separately since it's not + serialized (it's a runtime object computed from the graph). + + If no pattern is provided, creates a minimal RegionPattern from the + saved signature and size for signature matching purposes. + + Args: + data: Dictionary containing 'pattern_signature', 'pattern_size', + and 'schemes' keys + pattern: RegionPattern object to associate (must match signature). + If None, creates minimal pattern from saved data. + + Returns: + Reconstructed PatternSchemes instance + """ + # Import here to avoid circular dependency at runtime + from modelopt.onnx.quantization.autotune.region_pattern import RegionPattern + + ps = cls() + + # If no pattern provided, create minimal one from saved data + if pattern is None and "pattern_signature" in data: + pattern = RegionPattern( + signature=data["pattern_signature"], size=data.get("pattern_size", 0) + ) + + ps.pattern = pattern + + ps.schemes = [ + InsertionScheme.from_dict(scheme_data) for scheme_data in data.get("schemes", []) + ] + + return ps + + def __str__(self) -> str: + """String representation for debugging.""" + best_latency = self.best_scheme.latency_ms if self.best_scheme else 0.0 + return ( + f"PatternSchemes(pattern='{self.pattern_signature[:40]}...', " + f"schemes={self.num_schemes}, best_latency={best_latency:.3f}ms)" + ) + + +@dataclass +class PatternCache: + """Pattern cache containing best-performing schemes for patterns with automatic eviction. + + Stores a collection of PatternSchemes that can be used as seeds for autotuning. + Each PatternSchemes contains high-performing insertion schemes for a specific + pattern signature. The cache automatically evicts non-performant schemes based on: + - Error status (schemes with errors are evicted) + - Duplicate schemes (only better-performing duplicate is kept) + - Similarity (similar schemes where only better-performing one is kept) + - Count limit (only top N best schemes are kept per pattern) + + **Seeded Autotuning:** + - Use previous autotuning results as starting points + - Skip redundant measurements for known patterns + - Transfer learned schemes across models or runs + + **Use Cases:** + - Load pattern cache from previous run to warm-start autotuning + - Share pattern cache data across similar models + - Store best-known schemes for common patterns + + **Workflow:** + 1. After autotuning, add schemes to PatternCache (non-performant entries auto-evicted) + 2. Serialize PatternCache to file (YAML) + 3. Load PatternCache in future runs as seeds + 4. Autotuner uses seeds to initialize pattern schemes + + **Attributes:** + pattern_schemes: List of PatternSchemes, one per pattern + minimum_distance: Minimum edit distance required between schemes in cache. + When adding new schemes, if a scheme is too similar (distance < minimum_distance) + to an existing scheme, only the better-performing one is kept (default: 4) + max_entries_per_pattern: Maximum number of schemes to keep per pattern. + Only the top N best-performing schemes are kept for each pattern. + Use 0 to keep all schemes (default: 32) + + Example: + >>> # Save pattern cache after autotuning + >>> cache = PatternCache(minimum_distance=4, max_entries_per_pattern=32) + >>> for schemes in autotuner.pattern_schemes.values(): + ... 
cache.add_pattern_schemes(schemes) # Auto-eviction happens here + >>> cache.save("pattern_cache.yaml") + >>> + >>> # Load pattern cache for next run + >>> cache = PatternCache.load("pattern_cache.yaml") + >>> autotuner.initialize(config, pattern_cache=cache) + """ + + pattern_schemes: list[PatternSchemes] = field(default_factory=list) + # Minimum distance between schemes in cache. + minimum_distance: int = 4 + # Maximum number of schemes per pattern. + max_entries_per_pattern: int = 32 + + def add_pattern_schemes(self, pattern_schemes: PatternSchemes) -> None: + """Add PatternSchemes to pattern cache with automatic eviction of non-performant entries. + + Merges new schemes with existing schemes for the same pattern, automatically + evicting schemes that are non-performant based on multiple criteria. + + **Automatic Eviction Strategy:** + + 1. **Error Eviction**: Schemes with errors are automatically excluded + + 2. **Duplicate Eviction**: When schemes have identical configurations (same hash), + only the one with better latency is kept + + 3. **Similarity Eviction**: When minimum_distance > 0, schemes that are too similar + to better-performing schemes are evicted + + 4. **Count Eviction**: When max_entries_per_pattern > 0, only the top N + best-performing schemes are kept per pattern + + Args: + pattern_schemes: PatternSchemes to add to the cache + """ + if not pattern_schemes or not pattern_schemes.pattern: + return + + pattern_sig = pattern_schemes.pattern_signature + + # Find existing PatternSchemes for this pattern + existing_idx = None + for idx, ps in enumerate(self.pattern_schemes): + if ps.pattern_signature == pattern_sig: + existing_idx = idx + break + + # Collect all schemes (existing + new) + all_schemes = list(pattern_schemes.schemes) + if existing_idx is not None: + all_schemes.extend(self.pattern_schemes[existing_idx].schemes) + + # Filter out schemes with errors and deduplicate by hash + valid_schemes = [s for s in all_schemes if not s.has_error] + unique_schemes = {} + for scheme in valid_schemes: + scheme_hash = scheme.hash + if ( + scheme_hash not in unique_schemes + or scheme.latency_ms < unique_schemes[scheme_hash].latency_ms + ): + unique_schemes[scheme_hash] = scheme + + # Sort by latency to get best schemes + sorted_schemes = sorted(unique_schemes.values(), key=lambda s: s.latency_ms) + + # Apply distance-based filtering if minimum_distance > 0 + if self.minimum_distance > 0: + filtered_schemes = [] + for scheme in sorted_schemes: + # Check if this scheme is too similar to any already-filtered scheme + too_similar = False + for existing_scheme in filtered_schemes: + distance = scheme.distance(existing_scheme) + if distance < self.minimum_distance: + # Schemes are too similar, keep the better one + if scheme.latency_ms < existing_scheme.latency_ms: + # New scheme is better, remove existing and add new + filtered_schemes.remove(existing_scheme) + break + else: + # Existing scheme is better, skip new one + too_similar = True + break + + if not too_similar: + filtered_schemes.append(scheme) + + sorted_schemes = filtered_schemes + + # Apply count limit if max_entries_per_pattern > 0 + # Keep only the top N best-performing schemes per pattern + if self.max_entries_per_pattern > 0: + sorted_schemes = sorted_schemes[: self.max_entries_per_pattern] + + # Create PatternSchemes with all schemes that passed the eviction criteria + result = PatternSchemes(pattern=pattern_schemes.pattern) + result.schemes = sorted_schemes + + # Replace existing or append new + if existing_idx is 
not None: + self.pattern_schemes[existing_idx] = result + else: + self.pattern_schemes.append(result) + + def get_pattern_schemes(self, pattern_signature: str) -> PatternSchemes | None: + """Get PatternSchemes for a specific pattern signature. + + Args: + pattern_signature: Pattern signature to lookup + + Returns: + PatternSchemes if found, None otherwise + """ + for ps in self.pattern_schemes: + if ps.pattern_signature == pattern_signature: + return ps + return None + + def has_pattern(self, pattern_signature: str) -> bool: + """Check if pattern cache contains a specific pattern. + + Args: + pattern_signature: Pattern signature to check + + Returns: + True if pattern exists in pattern cache + """ + return any(ps.pattern_signature == pattern_signature for ps in self.pattern_schemes) + + def add_pattern_from_region( + self, region: Region, graph: gs.Graph, quantized_tensors: set[str] + ) -> None: + """Build and add a pattern cache entry from a region in a quantized model. + + Analyzes a region from an already-quantized model to extract its Q/DQ + insertion scheme. This allows capturing known-good quantization strategies + from existing models and using them as seeds for autotuning. + + **Workflow:** + 1. Create RegionPattern from the region structure + 2. Identify which tensors in the region are quantized + 3. Map quantized tensors to pattern-relative insertion points: + - Node input tensors → NodeInputInsertionPoint + - Child region input tensors → ChildRegionInputInsertionPoint + - Region output tensors → RegionOutputInsertionPoint + 4. Create InsertionScheme with identified insertion points + 5. Add to pattern cache (or merge if pattern already exists) + + Args: + region: Region from the quantized model to analyze + graph: ONNX graph containing the region + quantized_tensors: Set of tensor names that have Q/DQ nodes + + Example: + >>> cache = PatternCache() + >>> for region in all_regions: + ... 
cache.add_pattern_from_region(region, graph, quantized_tensors) + >>> cache.save("learned_patterns.yaml") + """ + # Import here to avoid circular dependency at runtime + from modelopt.onnx.quantization.autotune.region_pattern import RegionPattern + + # Create pattern from region + pattern = RegionPattern.from_region(region, graph) + # Track insertion points + scheme = InsertionScheme( + node_inputs=[], + child_region_inputs=[], + region_outputs=[], + latency_ms=float("inf"), + error=False, + ) + # Analyze node inputs + full_insertion_scheme = pattern.get_full_insertion_scheme(region, graph) + for point in full_insertion_scheme.node_inputs: + temp_scheme = InsertionScheme( + node_inputs=[point], + child_region_inputs=[], + region_outputs=[], + latency_ms=float("inf"), + error=False, + ) + temp_ips: list[ResolvedInsertionPoint] = pattern.matches(region, graph, temp_scheme) + temp_tensor_names = {tensor.tensor_name for tensor in temp_ips} + if len(temp_tensor_names.intersection(quantized_tensors)) > 0: + scheme.node_inputs.append(point) + # Analyze region boundaries (for COMPOSITE regions) + if region.type == RegionType.COMPOSITE: + for child_point in full_insertion_scheme.child_region_inputs: + temp_scheme = InsertionScheme( + node_inputs=[], + child_region_inputs=[child_point], + region_outputs=[], + latency_ms=float("inf"), + error=False, + ) + temp_ips = pattern.matches(region, graph, temp_scheme) + temp_tensor_names = {tensor.tensor_name for tensor in temp_ips} + if len(temp_tensor_names.intersection(quantized_tensors)) > 0: + scheme.child_region_inputs.append(child_point) + # Analyze region outputs + for output_point in full_insertion_scheme.region_outputs: + temp_scheme = InsertionScheme( + node_inputs=[], + child_region_inputs=[], + region_outputs=[output_point], + latency_ms=float("inf"), + error=False, + ) + temp_ips = pattern.matches(region, graph, temp_scheme) + temp_tensor_names = {tensor.tensor_name for tensor in temp_ips} + if len(temp_tensor_names.intersection(quantized_tensors)) > 0: + scheme.region_outputs.append(output_point) + # Add pattern and scheme to pattern cache + pattern_schemes = PatternSchemes(pattern=pattern, schemes=[scheme]) + self.add_pattern_schemes(pattern_schemes) + num_points = ( + len(scheme.node_inputs) + len(scheme.child_region_inputs) + len(scheme.region_outputs) + ) + logger.debug( + f"Added pattern from region {region.get_id()} with {num_points} insertion points" + ) + # Add patterns from child regions + if region.type == RegionType.COMPOSITE: + for child_region in region.get_children(): + self.add_pattern_from_region(child_region, graph, quantized_tensors) + + @property + def num_patterns(self) -> int: + """Get number of patterns in pattern cache.""" + return len(self.pattern_schemes) + + @property + def total_schemes(self) -> int: + """Get total number of schemes across all patterns.""" + return sum(ps.num_schemes for ps in self.pattern_schemes) + + def get_all_pattern_signatures(self) -> list[str]: + """Get list of all pattern signatures in pattern cache. + + Returns: + List of pattern signature strings + """ + return [ps.pattern_signature for ps in self.pattern_schemes] + + def clear(self) -> None: + """Clear all pattern cache data.""" + self.pattern_schemes.clear() + + def merge(self, other: "PatternCache", prefer_existing: bool = True) -> None: + """Merge another PatternCache into this one. + + Args: + other: PatternCache to merge + prefer_existing: If True, keep existing patterns when there's a conflict. 
+ If False, overwrite with other's patterns. + """ + for schemes in other.pattern_schemes: + if not self.has_pattern(schemes.pattern_signature) or not prefer_existing: + self.add_pattern_schemes(schemes) + + def to_dict(self) -> dict[str, Any]: + """Convert to dictionary for serialization. + + Returns: + Dictionary with 'minimum_distance', 'max_entries_per_pattern', and 'pattern_schemes' keys + """ + return { + "minimum_distance": self.minimum_distance, + "max_entries_per_pattern": self.max_entries_per_pattern, + "pattern_schemes": [ps.to_dict() for ps in self.pattern_schemes], + } + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "PatternCache": + """Create PatternCache from serialized dictionary. + + Note: RegionPattern objects are not restored (they're runtime objects). + Only pattern signatures and scheme data are loaded. + + Args: + data: Dictionary containing pattern cache data + + Returns: + Reconstructed PatternCache instance + """ + cache = cls( + minimum_distance=data.get("minimum_distance", 4), + max_entries_per_pattern=data.get("max_entries_per_pattern", 32), + ) + + for ps_data in data.get("pattern_schemes", []): + # Create PatternSchemes without pattern object (pattern=None) + ps = PatternSchemes.from_dict(ps_data, pattern=None) + cache.pattern_schemes.append(ps) + + return cache + + def save(self, output_path: str) -> None: + """Save pattern cache to a YAML file. + + Serializes all pattern schemes and their insertion points to a YAML file + that can be loaded later for seeded autotuning. The format matches the + autotuner state file format for consistency. + + **Contents:** + - minimum_distance: Minimum distance between schemes + - max_entries_per_pattern: Maximum number of schemes per pattern + - pattern_schemes: List of all PatternSchemes with their insertion points + + Args: + output_path: File path where the YAML pattern cache file will be written + + Example: + >>> cache = PatternCache(minimum_distance=1, max_entries_per_pattern=16) + >>> for schemes in autotuner.pattern_schemes.values(): + ... cache.add_pattern_schemes(schemes) + >>> cache.save("pattern_cache.yaml") + """ + state = self.to_dict() + + with open(output_path, "w") as f: + yaml.dump(state, f, default_flow_style=False, sort_keys=False) + + logger.info( + f"Saved pattern cache → {output_path} ({self.num_patterns} patterns, " + f"{self.total_schemes} schemes)" + ) + logger.debug( + f"Cache settings: min_distance={self.minimum_distance}, " + f"max_per_pattern={self.max_entries_per_pattern}" + ) + + @classmethod + def load(cls, input_path: str) -> "PatternCache": + """Load pattern cache from a YAML file. + + Reads a previously saved pattern cache file and reconstructs all pattern + schemes. The loaded pattern cache can be used to seed autotuning with + known-good insertion schemes. + + **Note:** RegionPattern objects are not restored since they depend on + the actual model structure. Only pattern signatures and scheme data + are loaded. 
+ + Args: + input_path: File path to the YAML pattern cache file to load + + Returns: + PatternCache instance with all pattern schemes loaded + + Raises: + FileNotFoundError: If the input_path doesn't exist + + Example: + >>> cache = PatternCache.load("pattern_cache.yaml") + >>> autotuner.initialize(config, pattern_cache=cache) + """ + with open(input_path) as f: + state = yaml.safe_load(f) + + cache = cls.from_dict(state) + + logger.info( + f"Loaded pattern cache from {input_path} ({cache.num_patterns} patterns, " + f"{cache.total_schemes} schemes)" + ) + logger.debug( + f"Cache settings: min_distance={cache.minimum_distance}, " + f"max_per_pattern={cache.max_entries_per_pattern}" + ) + + return cache + + def __str__(self) -> str: + """String representation for debugging.""" + return ( + f"PatternCache(patterns={self.num_patterns}, " + f"schemes={self.total_schemes}, " + f"minimum_distance={self.minimum_distance}, " + f"max_entries_per_pattern={self.max_entries_per_pattern})" + ) + + +@dataclass +class Config: + """Configuration parameters for QDQ autotuning. + + Controls the autotuning process including performance requirements, quantization + parameters, region building, scheme generation, and finetuning behavior. + + Attributes: + # Logging + verbose: Enable detailed logging of autotuning progress (default: False) + + # Quantization Parameters + default_q_scale: Default scale parameter for Q/DQ nodes. Controls quantization + granularity. Typical range: 0.01-0.1 (default: 0.1) + default_q_zero_point: Default zero-point for Q/DQ nodes. Use 0 for signed int8, + 128 for unsigned uint8 (default: 0) + default_quant_type: Quantization type for Q/DQ nodes. Options: "int8" (default), "fp8" + + # Region Builder Settings + maximum_sequence_region_size: Maximum number of nodes in a sequence region during + top-down refinement. Prevents overly large merged regions (default: 10) + minimum_topdown_search_size: Minimum number of nodes in a region to trigger + top-down search during region building (default: 10) + + # Scheme Generation Settings + top_percent_to_mutate: Top percentage of best schemes to use as mutation seeds + during scheme generation. Range: 0.0-1.0 (default: 0.1 = top 10%) + minimum_schemes_to_mutate: Minimum number of schemes to keep as mutation seeds, + even if top_percent_to_mutate results in fewer (default: 10) + maximum_mutations: Maximum number of mutations to apply to a single scheme + during generation (default: 3) + maximum_generation_attempts: Maximum attempts to generate a unique new scheme + before giving up (default: 100) + + # Pattern Cache Settings + pattern_cache_minimum_distance: Minimum edit distance required between schemes in cache. + When adding schemes, if a scheme is too similar (distance < minimum_distance) + to an existing scheme, only the better-performing one is kept (default: 4) + pattern_cache_max_entries_per_pattern: Maximum number of schemes to keep per pattern + in pattern cache. Only the top N best-performing schemes are kept for each pattern. + Use 0 to keep all schemes (default: 32) + + Example: + >>> config = Config( + ... verbose=True, # Enable detailed logging + ... top_percent_to_mutate=0.2, # Use top 20% schemes as seeds + ... pattern_cache_minimum_distance=2, # Require more diversity in cache + ... 
) + >>> autotuner = QDQAutotuner(model) + >>> autotuner.initialize(config) + """ + + # Logging + verbose: bool = False + + # Quantization Parameters + default_q_scale: float = 0.1 + default_q_zero_point: int = 0 + default_quant_type: str = "int8" + default_dq_dtype: str = "float32" + + # Region Builder Settings + maximum_sequence_region_size: int = 10 + minimum_topdown_search_size: int = 10 + + # Scheme Generation Settings + top_percent_to_mutate: float = 0.1 + minimum_schemes_to_mutate: int = 10 + maximum_mutations: int = 3 + maximum_generation_attempts: int = 100 + + # Pattern Cache Settings + pattern_cache_minimum_distance: int = 4 + pattern_cache_max_entries_per_pattern: int = 32 diff --git a/modelopt/onnx/quantization/autotune/region_inspect.py b/modelopt/onnx/quantization/autotune/region_inspect.py new file mode 100644 index 000000000..32b7cc58a --- /dev/null +++ b/modelopt/onnx/quantization/autotune/region_inspect.py @@ -0,0 +1,201 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Region search inspection tool for ONNX models.""" + +import argparse +import logging +import sys +from collections import Counter + +import onnx +import onnx_graphsurgeon as gs + +from modelopt.onnx.logging_config import logger +from modelopt.onnx.quantization.autotune.common import Region, RegionType +from modelopt.onnx.quantization.autotune.insertion_points import has_quantizable_operations +from modelopt.onnx.quantization.autotune.region_search import ( + DEFAULT_MAX_STEPS, + CombinedRegionSearch, +) + + +def inspect_region_search( + onnx_path: str, + max_sequence_size: int = 10, + include_all_regions: bool = False, +) -> list[Region]: + """Inspect region search results for an ONNX model. + + This function loads an ONNX model, runs CombinedRegionSearch (which performs + both bottom-up partitioning and top-down refinement internally), and prints + detailed information about the discovered regions including their hierarchical + structure. + + **What it does:** + 1. Loads ONNX model and converts to GraphSurgeon format + 2. Creates CombinedRegionSearch instance with specified parameters + 3. Runs two-phase search (partitioning + refinement) via search() + 4. Displays detailed region structure and statistics + 5. Returns the final list of refined regions + + **Output Sections:** + - Initialization: Shows search parameters + - Two-Phase Search: Runs automatically via CombinedRegionSearch.search() + - Detailed Structure: Shows each region's hierarchy and properties + - Summary Statistics: Shows region counts and node coverage + + Args: + onnx_path: Path to the ONNX model file + max_sequence_size: Maximum size for sequence regions during refinement (default: 10) + include_all_regions: Include all regions, even those without major quantizable + operations (Conv, MatMul, etc.). 
Default: False (skips such regions) + + Returns: + List of discovered and refined regions (LEAF and COMPOSITE) + """ + # Load ONNX model + logger.info(f"Loading model: {onnx_path}") + onnx_model = onnx.load(onnx_path) + # Convert to onnx_graphsurgeon Graph + graph = gs.import_onnx(onnx_model) + graph.cleanup().toposort() + logger.info( + f"Loaded graph: {len(graph.nodes)} nodes, {len(graph.inputs)} inputs, {len(graph.outputs)} outputs" + ) + # Initialize CombinedRegionSearch (contains RegionPartitioner internally) + logger.debug( + f"Search parameters: max_steps={DEFAULT_MAX_STEPS}, max_sequence_size={max_sequence_size}" + ) + + combined_search = CombinedRegionSearch(graph, maximum_sequence_region_size=max_sequence_size) + + # Run complete two-phase region search + logger.info("Running region search") + regions = combined_search.search_regions() + # Show detailed region structure + logger.info("Analyzing region structure") + all_regions = [] + for i, region in enumerate(regions): + for child in region.get_children(): + if not include_all_regions and not has_quantizable_operations(child, graph): + region.remove_child(child) + if not include_all_regions and not has_quantizable_operations(region, graph): + logger.debug(f"Filtered out region {i} (no quantizable operations)") + continue + logger.debug( + f"Region {i}: {region.type.value}, {len(region.get_region_nodes_and_descendants())} nodes, " + f"{len(region.inputs)} inputs, {len(region.outputs)} outputs" + ) + all_regions.append(region) + if region.type == RegionType.COMPOSITE: + logger.debug(f" {len(region.get_children())} child regions") + all_regions.extend(region.get_children()) + combined_search.print_tree(region, indent=2) + + # Summary statistics + type_counts = Counter(r.type for r in all_regions) + leaf_regions, composite_regions = ( + type_counts[RegionType.LEAF], + type_counts[RegionType.COMPOSITE], + ) + + all_nodes = {n for r in all_regions for n in r.get_region_nodes_and_descendants()} + total_nodes = len(all_nodes) + coverage_pct = 100 * total_nodes / len(graph.nodes) if graph.nodes else 0 + + logger.info( + f"Summary: {len(all_regions)} regions ({leaf_regions} LEAF, {composite_regions} COMPOSITE), " + f"{total_nodes}/{len(graph.nodes)} nodes ({coverage_pct:.1f}%)" + ) + + # Print histogram of region sizes + region_sizes = [ + len(r.get_region_nodes_and_descendants()) for r in all_regions if r.type == RegionType.LEAF + ] + + if region_sizes: + min_size = min(region_sizes) + max_size = max(region_sizes) + avg_size = sum(region_sizes) / len(region_sizes) + + logger.info(f"LEAF region sizes: min={min_size}, max={max_size}, avg={avg_size:.1f}") + size_counts = Counter(region_sizes) + logger.debug("Size distribution:") + for size in sorted(size_counts.keys()): + count = size_counts[size] + bar = "█" * min(count, 50) + logger.debug(f" {size:4d} nodes: {bar} ({count} regions)") + + return regions + + +def main(): + """Command-line entry point for region search inspection.""" + parser = argparse.ArgumentParser( + prog="modelopt.onnx.quantization.autotune.region_inspect", + description="Inspect region search results for ONNX models", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + # Basic inspection + python -m modelopt.onnx.quantization.autotune.region_inspect --model model.onnx + + # Verbose mode for debug logging + python -m modelopt.onnx.quantization.autotune.region_inspect \\ + --model model.onnx --verbose + + # Custom maximum sequence size + python -m 
modelopt.onnx.quantization.autotune.region_inspect \\ + --model model.onnx --max-sequence-size 20 + """, + ) + + parser.add_argument("--model", "-m", type=str, required=True, help="Path to ONNX model file") + parser.add_argument( + "--max-sequence-size", + type=int, + default=10, + help="Maximum size for sequence regions during refinement (default: 10)", + ) + parser.add_argument( + "--include-all-regions", + action="store_true", + help="Include all regions, even those without major quantizable operations. " + "Default: False (skips such regions)", + ) + parser.add_argument("--verbose", "-v", action="store_true", help="Enable verbose debug logging") + + args = parser.parse_args() + + log_level = logging.DEBUG if args.verbose else logging.INFO + logging.basicConfig(level=log_level, format="%(asctime)s - %(levelname)s - %(message)s") + logger.setLevel(log_level) + + try: + regions = inspect_region_search( + onnx_path=args.model, + max_sequence_size=args.max_sequence_size, + include_all_regions=args.include_all_regions, + ) + logger.info(f"✓ Inspection complete: {len(regions)} top-level regions discovered") + return 0 + except Exception as e: + logger.error(f"Inspection failed: {e}", exc_info=args.verbose) + return 1 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/modelopt/onnx/quantization/autotune/region_pattern.py b/modelopt/onnx/quantization/autotune/region_pattern.py new file mode 100644 index 000000000..9abd42fd4 --- /dev/null +++ b/modelopt/onnx/quantization/autotune/region_pattern.py @@ -0,0 +1,669 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Region Pattern Signature Generator. + +Provides structural pattern analysis for regions in ONNX computation graphs. 
+This module enables: +- Pattern-based region grouping by structural similarity +- Deterministic signature generation for pattern matching +- Resolution of insertion points to actual tensor names +- Support for both node-level and region-level Q/DQ insertion + +Key concepts: +- NodeInputInsertionPoint: Specifies Q/DQ insertion at a node's input +- ChildRegionInputInsertionPoint: Specifies Q/DQ insertion at a child region's input boundary +- RegionOutputInsertionPoint: Specifies Q/DQ insertion at a region output (child or node) +- Pattern matching: Groups regions with identical structure for shared optimization +""" + +import hashlib +import logging +from typing import Union + +import onnx_graphsurgeon as gs + +from modelopt.onnx.quantization.autotune.common import InsertionScheme, Region +from modelopt.onnx.quantization.autotune.insertion_points import ( + ChildRegionInputInsertionPoint, + NodeInputInsertionPoint, + RegionOutputInsertionPoint, + ResolvedInsertionPoint, +) + +# Module logger +logger = logging.getLogger(__name__) + +# Commutative/symmetric operations where operand order doesn't matter +SYMMETRIC_OPERATIONS = { + "Add", + "Mul", + "And", + "Or", + "Xor", + "Equal", + "Max", + "Min", + "Sum", + "Mean", + "BitwiseAnd", + "BitwiseOr", + "BitwiseXor", +} + + +class RegionPattern: + """Represents a structural pattern of a region. + + The pattern captures the topology and operation types in a region, + enabling pattern matching and region comparison. Patterns are hashable + and can be used as dictionary keys for efficient grouping and lookup. + + Two RegionPattern objects are considered equal if they have the same + signature string, regardless of their size (which represents instance-specific + node count). + + Attributes: + signature: The unique signature string identifying the pattern + size: Total node count for this pattern instance + """ + + # ========================================================================= + # Initialization + # ========================================================================= + + def __init__(self, signature: str, size: int): + """Initialize a region pattern. 
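+
+        Example (the signature shown is a hypothetical LEAF pattern, not one
+        computed from a real model):
+
+            >>> a = RegionPattern("Conv->Relu", size=2)
+            >>> b = RegionPattern("Conv->Relu", size=5)
+            >>> a == b  # equality compares signatures only, not sizes
+            True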
+ + Args: + signature: The signature string representing the pattern structure + size: Total size (node count) of the region + """ + self.signature = signature + self.size = size + + # ========================================================================= + # Properties + # ========================================================================= + + @property + def is_empty(self) -> bool: + """Check if pattern represents an empty region.""" + return self.signature == "EMPTY" or self.size == 0 + + @property + def is_composite(self) -> bool: + """Check if pattern represents a composite region.""" + return self.signature.startswith("COMPOSITE(") + + @property + def is_leaf(self) -> bool: + """Check if pattern represents a leaf region (no composite structure).""" + return not self.is_composite and not self.is_empty + + # ========================================================================= + # Special Methods (Python Protocol) + # ========================================================================= + + def __str__(self) -> str: + """String representation showing just the signature.""" + return self.signature + + def __repr__(self) -> str: + """Developer-friendly representation with signature and size.""" + return f"RegionPattern('{self.signature}', size={self.size})" + + def __eq__(self, other) -> bool: + """Check equality based on signature only.""" + if not isinstance(other, RegionPattern): + return False + return self.signature == other.signature + + def __hash__(self) -> int: + """Hash based on signature for use as dict key.""" + return hash(self.signature) + + # ========================================================================= + # Public Query Methods + # ========================================================================= + + def get_hash(self) -> str: + """Get a 128-bit cryptographic hash of the pattern signature. + + Uses SHA-256 (truncated to 128 bits) to generate a compact, deterministic + hash for efficient pattern comparison and storage. This hash is more + compact than the full signature for storage and comparison purposes. + + Returns: + Hexadecimal string representation of the hash (32 characters) + + Example: + >>> pattern = RegionPattern.from_region(region, graph) + >>> hash_val = pattern.get_hash() # Returns 32 hex characters + >>> print(f"Pattern hash: {hash_val}") + """ + # SHA-256 truncated to 128 bits = 32 hex characters + return hashlib.sha256(self.signature.encode("utf-8")).hexdigest()[:32] + + def get_short_signature(self, max_length: int = 80) -> str: + """Get a truncated version of the signature for display purposes. + + Args: + max_length: Maximum length of the returned string (default: 80) + + Returns: + Truncated signature with '...' suffix if needed + """ + if len(self.signature) <= max_length: + return self.signature + return self.signature[: max_length - 3] + "..." + + # ========================================================================= + # Public Pattern Matching and Construction + # ========================================================================= + + @classmethod + def from_region(cls, region: Region, graph: gs.Graph) -> "RegionPattern": + """Compute a structural pattern for a region. 
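+
+        Example:
+            A sketch, assuming ``region`` and ``graph`` come from the region
+            search utilities in this package (the signature shown is
+            hypothetical):
+
+            >>> pattern = RegionPattern.from_region(region, graph)
+            >>> pattern.signature  # doctest: +SKIP
+            'COMPOSITE(Conv->Relu+Conv->Relu)'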
+ + The pattern captures: + - Direct node operations in the region + - Structure of sub-regions (recursively) + - Handles symmetric operations consistently + - Sorts sub-regions by size for determinism + + Args: + region: The region to compute pattern for + graph: The ONNX graph containing the nodes + + Returns: + RegionPattern object containing the signature and metadata + """ + signature_str = cls._compute_signature_recursive(region, graph) + total_size = region.get_total_size() + return cls(signature_str, total_size) + + def matches( + self, + other: Union["RegionPattern", Region], + graph: gs.Graph | None = None, + scheme: InsertionScheme | None = None, + ) -> bool | list[int] | set[ResolvedInsertionPoint] | None: + """Check if this pattern matches another pattern or region. + + This method provides three distinct behaviors depending on the arguments: + + 1. **Pattern-to-pattern comparison** (other is RegionPattern, scheme is None): + Returns bool indicating structural equivalence. + + 2. **Pattern-to-region matching** (other is Region, scheme is None): + Returns list of node IDs in pattern order if match succeeds, None otherwise. + + 3. **Pattern-to-region with insertion scheme** (other is Region, scheme provided): + Returns set of resolved insertion points where Q/DQ should be inserted, considering: + - NodeInputInsertionPoints from the scheme (node-level Q/DQ) + - ChildRegionInputInsertionPoints from the scheme (child region input Q/DQ) + - RegionOutputInsertionPoints from the scheme (region output Q/DQ) + Returns empty set if pattern doesn't match. + + Args: + other: Either a RegionPattern or Region to compare with + graph: Required when other is a Region (for computing its pattern) + scheme: Optional InsertionScheme containing node_inputs, + child_region_inputs, and region_outputs + to resolve to tensor names + + Returns: + - bool: If other is RegionPattern, True if patterns match + - List[int]: If other is Region and scheme is None, list of node IDs + in pattern order (None if no match) + - Set[ResolvedInsertionPoint]: If other is Region and scheme is provided, + set of resolved insertion points for Q/DQ insertion (empty set if no match) + + Raises: + ValueError: If other is Region but graph is not provided, or if scheme + is provided but other is not a Region + TypeError: If other is neither RegionPattern nor Region + """ + if isinstance(other, RegionPattern): + # Behavior 1: Pattern-to-pattern comparison + if scheme is not None: + raise ValueError("scheme parameter can only be used when matching against a Region") + return self._matches_pattern(other) + elif isinstance(other, Region) and scheme is None: + # Behavior 2: Pattern-to-region matching (returns node IDs) + return self._matches_region(other, graph) + elif isinstance(other, Region) and scheme is not None: + if graph is None: + raise ValueError("graph parameter is required") + # Verify the region matches this pattern + region_pattern = RegionPattern.from_region(other, graph) + if self != region_pattern: + return set() + + resolved_ips = set() + # Resolve NodeInputInsertionPoints to tensor names + for ip in scheme.node_inputs: + resolved_ips.update(ip.resolve(other, graph)) + # Resolve ChildRegionInputInsertionPoints to tensor names + for ip in scheme.child_region_inputs: + resolved_ips.update(ip.resolve(other, graph)) + # Resolve RegionOutputInsertionPoints to tensor names + for ip in scheme.region_outputs: + resolved_ips.update(ip.resolve(other, graph)) + return resolved_ips + else: + raise TypeError(f"Expected 
RegionPattern or Region, got {type(other).__name__}") + + # ========================================================================= + # Private Pattern Matching Helpers + # ========================================================================= + + def _matches_pattern(self, other: "RegionPattern") -> bool: + """Internal function: Match this pattern against another pattern. + + Args: + other: Another RegionPattern to compare with + + Returns: + True if patterns are structurally equivalent, False otherwise + """ + return self == other + + def _matches_region(self, region: Region, graph: gs.Graph | None) -> list[int] | None: + """Internal function: Match this pattern against a region. + + Args: + region: The region to match against + graph: The ONNX graph containing the nodes + + Returns: + List of node IDs in match order if pattern matches, None otherwise. + Match order follows the pattern computation order: + - Direct nodes of the region (sorted) + - Then recursively, nodes from child regions (in child sort order) + + Raises: + ValueError: If graph is not provided + """ + if graph is None: + raise ValueError("graph parameter is required when matching against a Region") + + # Compute pattern for the region + region_pattern = RegionPattern.from_region(region, graph) + + # Check if patterns match + if self == region_pattern: + # Return node IDs in match order (same as signature computation order) + return self._collect_nodes_in_match_order(region) + else: + return None + + def get_full_insertion_scheme(self, region: Region, graph: gs.Graph) -> InsertionScheme: + """Get all possible insertion points for a region in a single InsertionScheme. + + This method first verifies that the region matches this pattern (raises if not). + It then collects all three types of insertion points: + 1. Node input insertion points (Q/DQ at node inputs within the region) + 2. Child region input insertion points (Q/DQ at child region input boundaries) + 3. Region output insertion points (Q/DQ at region output boundaries) + + The returned InsertionScheme contains all possible Q/DQ insertion + locations for this region pattern. This can be used as: + - A baseline scheme with all possible insertions + - A starting point for optimization algorithms + - A comprehensive view of all insertion opportunities + + Important: Pattern-relative indices in the returned scheme are based on + sorted child/node ordering. The sorting order (-level, size) MUST match + insertion_points.py for correct resolution. + + Note: The returned scheme has no child region schemes specified, + latency is set to infinity (unmeasured), and error flag is False. 
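+
+        Example:
+            A sketch, assuming ``region`` and ``graph`` are the pair this
+            pattern was computed from (the count shown is hypothetical):
+
+            >>> scheme = pattern.get_full_insertion_scheme(region, graph)
+            >>> len(scheme.node_inputs)  # doctest: +SKIP
+            4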
+ + Args: + region: The region to analyze + graph: The ONNX graph containing the nodes + + Returns: + InsertionScheme containing all possible insertion points for this region + + Raises: + AssertionError: If the region doesn't match this pattern + """ + # Verify that the region matches this pattern + region_pattern = RegionPattern.from_region(region, graph) + assert self == region_pattern, "Region pattern mismatch" + + scheme = InsertionScheme() + # Collect all node input insertion points + scheme.node_inputs = NodeInputInsertionPoint.collect_from_region(region, graph) + # Collect all child region input insertion points (at child boundaries) + scheme.child_region_inputs = ChildRegionInputInsertionPoint.collect_from_region( + region, graph + ) + # Collect all region output insertion points + scheme.region_outputs = RegionOutputInsertionPoint.collect_from_region(region, graph) + + return scheme + + def format_tree(self, region: Region, graph: gs.Graph, indent: int = 0) -> str: + """Format this pattern and region as a human-readable tree. + + Useful for debugging and visualization. + + Args: + region: The region associated with this pattern + graph: The ONNX graph + indent: Indentation level + + Returns: + Formatted string representation + """ + prefix = " " * indent + result = f"{prefix}Region {region.get_id()}: {self.signature} (size={self.size})\n" + + for child in region.get_children(): + child_pattern = RegionPattern.from_region(child, graph) + result += child_pattern.format_tree(child, graph, indent + 1) + + return result + + # ========================================================================= + # Static Utility Methods + # ========================================================================= + + @staticmethod + def _collect_nodes_in_match_order(region: Region) -> list[int]: + """Collect node IDs in the same order as signature computation. + + This follows the traversal order used by _compute_signature_recursive: + 1. Direct nodes of the region (sorted by node index) + 2. Recursively, nodes from child regions (children sorted by -level, then size) + + The child sorting order MUST match _compute_signature_recursive and + insertion_points.py for correct pattern-relative index alignment. + + Args: + region: The region to collect nodes from + + Returns: + List of node IDs in match order + """ + node_ids = [] + + # Add direct nodes of this region (sorted) + node_ids.extend(sorted(region.get_nodes())) + + # Get children and sort them the same way as signature computation + # CRITICAL: This sorting must match _compute_signature_recursive and insertion_points.py + # Sort by: 1) level (descending - higher level first), 2) size (ascending) + children = region.get_children() + sorted_children = sorted(children, key=lambda r: (-r.get_level(), r.get_total_size())) + + # Recursively collect nodes from children in order + for child in sorted_children: + node_ids.extend(RegionPattern._collect_nodes_in_match_order(child)) + + return node_ids + + # --- Signature Computation --- + + @staticmethod + def _compute_signature_recursive(region: Region, graph: gs.Graph) -> str: + """Recursively compute structural signature for a region. 
+ + The signature captures: + - Node operations and their key parameters (for LEAF regions) + - Hierarchical structure with child patterns (for COMPOSITE regions) + - Deterministic ordering (sorted nodes and children) + - Normalized handling of symmetric/commutative operations + + Signature formats: + - Empty region: "EMPTY" + - Leaf region: "Op1->Op2->Op3" or "Op1[params]->Op2[params]" + - Composite with nodes: "COMPOSITE(nodes|child1+child2)" + - Composite without nodes: "COMPOSITE(child1+child2)" + + Child Sorting: + - Children are sorted by (-level, size) for deterministic signatures + - This order MUST match insertion_points.py for correct pattern-relative indexing + - Higher-level (more abstract) children come first + - Within same level, smaller children come first + + Args: + region: The region to process + graph: The ONNX graph containing the nodes + + Returns: + Deterministic signature string representing the region structure + """ + # Collect direct node operations in this region + node_ops = [] + nodes_list = list(graph.nodes) + node_indices_set = region.get_nodes() + + for node_idx in sorted(node_indices_set): + if node_idx < len(nodes_list): + node = nodes_list[node_idx] + # Include operation type and key parameters + # Pass region node indices for symmetric operation handling + node_sig = RegionPattern._make_node_with_params_signature( + node, graph, node_indices_set + ) + node_ops.append(node_sig) + + # Get child regions + children = region.get_children() + + if not children and not node_ops: + # Empty region (edge case) + return "EMPTY" + + if not children: + # LEAF region - only direct nodes, no hierarchical structure + return RegionPattern._make_node_signature(node_ops) + + # COMPOSITE region - has hierarchical structure with children + # Sort children deterministically for consistent signatures + # CRITICAL: This sorting must match insertion_points.py for pattern-relative index alignment + # Sort by: 1) level (descending - higher level first), 2) size (ascending) + sorted_children = sorted(children, key=lambda r: (-r.get_level(), r.get_total_size())) + + # Recursively compute child signatures + child_signatures = [] + for child in sorted_children: + child_sig = RegionPattern._compute_signature_recursive(child, graph) + child_signatures.append(child_sig) + + # Combine node operations and child signatures + if node_ops: + # Has both direct nodes and hierarchical children + node_sig = RegionPattern._make_node_signature(node_ops) + return f"COMPOSITE({node_sig}|{RegionPattern._join_signatures(child_signatures)})" + else: + # Only children, no direct nodes in this region + return f"COMPOSITE({RegionPattern._join_signatures(child_signatures)})" + + @staticmethod + def _make_node_with_params_signature( + node: gs.Node, graph: gs.Graph, region_node_indices: set + ) -> str: + """Create signature for a single node including its parameters. + + Includes operation type and key attributes that affect behavior. + For symmetric/commutative operations (Add, Mul, etc.), normalizes + input order to ensure consistent signatures regardless of operand order. + Ensures deterministic ordering by sorting attributes by key name. 
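+
+        Normalization means that ``Add(A, B)`` and ``Add(B, A)`` yield the
+        same signature, e.g. (hypothetical producers)::
+
+            Add<external:input-or-constant,internal:Conv>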
+
+        Args:
+            node: The ONNX node
+            graph: The ONNX graph containing all nodes
+            region_node_indices: Set of node indices in the current region
+
+        Returns:
+            Signature string examples:
+            - "Relu" - Simple operation without attributes
+            - "Conv[dilations=1x1,kernel_shape=3x3]" - Operation with attributes
+            - "Add<external:input-or-constant,internal:Conv>" - Symmetric op with sorted input sources
+            - "Mul[axis=1]<internal:Conv,internal:Mul>" - Symmetric op with both attributes and inputs
+        """
+        op = node.op
+
+        # Handle symmetric operations - normalize input order
+        if op in SYMMETRIC_OPERATIONS and len(node.inputs) > 1:
+            # Get input source information for normalization
+            input_sources = []
+            nodes_list = list(graph.nodes)
+
+            # Build node index lookup for efficient producer finding
+            node_to_idx = {id(n): idx for idx, n in enumerate(nodes_list)}
+
+            for inp in node.inputs:
+                if inp is None or not hasattr(inp, "inputs") or not inp.inputs:
+                    # Input from graph input or constant
+                    input_sources.append(("external", "input-or-constant"))
+                else:
+                    # Input from another node's output
+                    producer_node = inp.inputs[0] if inp.inputs else None
+                    if producer_node and id(producer_node) in node_to_idx:
+                        producer_idx = node_to_idx[id(producer_node)]
+                        # Check if producer is in the same region
+                        if producer_idx in region_node_indices:
+                            # Use relative position: 'internal' + producer op type
+                            input_sources.append(("internal", producer_node.op))
+                        else:
+                            # Producer outside region
+                            input_sources.append(("external", producer_node.op))
+                    else:
+                        # Unknown producer
+                        input_sources.append(("external", "unknown"))
+
+            # Sort input sources for deterministic ordering
+            # This ensures Add(A,B) and Add(B,A) have the same signature
+            sorted_sources = sorted(input_sources)
+
+            # Create source signature
+            source_sig = ",".join(f"{src[0]}:{src[1]}" for src in sorted_sources)
+
+            # If node has no attributes, return op with input signature
+            if not node.attrs:
+                return f"{op}<{source_sig}>"
+
+            # Otherwise, will add input signature after attributes
+            has_symmetric_inputs = True
+        else:
+            has_symmetric_inputs = False
+
+        # Handle non-symmetric operations or symmetric ops without multiple inputs
+        if not node.attrs and not has_symmetric_inputs:
+            return op
+
+        # Extract and format key attributes (only if node has attributes)
+        if node.attrs:
+            # Sort attributes alphabetically for deterministic ordering
+            attr_parts = []
+            for key in sorted(node.attrs.keys()):
+                value = node.attrs[key]
+
+                # Format different attribute types deterministically
+                if isinstance(value, (list, tuple)):
+                    # Format lists/tuples compactly
+                    # Use 'x' separator for numeric arrays (common in ONNX)
+                    if len(value) > 0 and all(isinstance(v, (int, float)) for v in value):
+                        # Format each element consistently
+                        if all(isinstance(v, int) for v in value):
+                            value_str = "x".join(str(v) for v in value)
+                        else:
+                            # Mixed int/float - format floats with limited precision
+                            value_str = "x".join(
+                                f"{v:.4g}" if isinstance(v, float) else str(v) for v in value
+                            )
+                    else:
+                        # Non-numeric or mixed types - use comma separator
+                        value_str = ",".join(str(v) for v in value)
+                elif isinstance(value, float):
+                    # Format floats with limited precision to avoid floating point noise
+                    value_str = f"{value:.4g}"
+                elif isinstance(value, bool):
+                    # Format booleans as 0/1 for compactness
+                    value_str = "1" if value else "0"
+                elif isinstance(value, bytes):
+                    # Format bytes as hex string (truncated for long values)
+                    hex_str = value.hex()
+                    value_str = hex_str if len(hex_str) <= 16 else f"{hex_str[:16]}..." 
+ else: + # Default: convert to string + value_str = str(value) + + attr_parts.append(f"{key}={value_str}") + + # Build final signature with attributes + attr_sig = f"[{','.join(attr_parts)}]" + + # Add symmetric input signature if applicable + if has_symmetric_inputs: + return f"{op}{attr_sig}<{source_sig}>" + else: + return f"{op}{attr_sig}" + else: + # No attributes - already handled above for symmetric ops + return op + + @staticmethod + def _make_node_signature(ops: list[str]) -> str: + """Create signature from list of node operations. + + Handles single and multiple operations, including symmetric operations. + + Args: + ops: List of operation signatures (may include parameters) + + Returns: + Signature string for the operations + """ + if not ops: + return "" + + if len(ops) == 1: + return ops[0] + + # Multiple operations - create sequential signature + return "->".join(ops) + + @staticmethod + def _join_signatures(signatures: list[str]) -> str: + """Join multiple child signatures. + + Sorts signatures alphabetically to ensure deterministic ordering. + This is critical for pattern matching and comparison. + + Args: + signatures: List of child signatures + + Returns: + Combined signature string with deterministic ordering + """ + if not signatures: + return "" + + if len(signatures) == 1: + return signatures[0] + + # Sort signatures alphabetically for deterministic ordering + # This ensures that parallel/sibling regions always produce + # the same combined signature regardless of traversal order + sorted_sigs = sorted(signatures) + return "+".join(sorted_sigs) diff --git a/modelopt/onnx/quantization/autotune/region_search.py b/modelopt/onnx/quantization/autotune/region_search.py new file mode 100644 index 000000000..62906fd50 --- /dev/null +++ b/modelopt/onnx/quantization/autotune/region_search.py @@ -0,0 +1,2348 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Region Search - Hierarchical Region Discovery and Partitioning. + +This module provides sophisticated algorithms for discovering and organizing regions +in ONNX computation graphs. It creates hierarchical region structures that respect +computational patterns like divergence, convergence, and sequential operations. + +**Core Functionality:** +- **Two-Phase Region Discovery**: Combines bottom-up partitioning with top-down refinement +- **Pattern Recognition**: Identifies divergence/convergence patterns in computation flow +- **Hierarchical Structure**: Creates COMPOSITE regions containing LEAF child regions +- **Boundary Computation**: Automatically determines region input/output tensors +- **Graph Analysis**: Pre-computes reachability and data flow information + +**Key Algorithms:** + +1. 
**Bottom-Up Partitioning (RegionPartitioner)**: + - Traverses graph from inputs to outputs + - Identifies divergent nodes where computation branches + - Finds convergence points where branches rejoin + - Creates initial LEAF regions based on these patterns + +2. **Top-Down Refinement (TopDownRegionBuilder)**: + - Merges converged sub-patterns within regions + - Splits long sequences into optimal-sized regions + - Creates hierarchical COMPOSITE region structures + - Respects operation boundaries (Conv, Gemm, etc.) + +3. **Combined Strategy (CombinedRegionSearch)**: + - Orchestrates both phases for comprehensive region discovery + - Produces well-formed hierarchical regions covering entire graph + +**Region Types:** +- **LEAF regions**: Contain actual graph nodes (basic building blocks) +- **COMPOSITE regions**: Contain child regions (hierarchical organization) +- **ROOT region**: Single region containing all graph nodes (for analysis) + +**Use Cases:** +- Graph partitioning for distributed execution +- Identifying optimization boundaries for quantization/pruning +- Creating hierarchical abstractions of computation +- Analyzing graph structure and computational patterns + +**Key Classes:** +- **RegionSearchBase**: Base class with common graph analysis utilities +- **CombinedRegionSearch**: Main two-phase region discovery algorithm +- **RegionPartitioner**: Bottom-up partitioning based on divergence/convergence +- **TopDownRegionBuilder**: Top-down refinement creating hierarchical structure +""" + +import argparse +import logging +import sys +from collections import Counter, deque + +import onnx +import onnx_graphsurgeon as gs + +from modelopt.onnx.quantization.autotune.common import Region, RegionType +from modelopt.onnx.quantization.autotune.insertion_points import has_quantizable_operations +from modelopt.onnx.quantization.autotune.region_pattern import RegionPattern +from modelopt.onnx.quantization.graph_utils import get_tensor_consumer_node_indices + +# Module logger +logger = logging.getLogger(__name__) + + +def enable_debug(): + """Enable debug-level logging for the region search module.""" + global logger + logger.setLevel(logging.DEBUG) + + +DEFAULT_MAX_STEPS = 10 +DEFAULT_MAX_NODES_TO_SHOW = 20 + + +class RegionSearchBase: + """Base class for region search algorithms providing common graph analysis utilities. + + This class serves as a foundation for region-based graph analysis algorithms by + providing essential data structures and methods for: + - Graph traversal and reachability analysis + - Divergence/convergence pattern detection + - Region boundary computation + - Tensor flow tracking + + **Core Data Structures:** + - **tensor_users_map**: Maps tensor names to node indices that consume them. + Used to efficiently find divergence points and track data flow. + - **forward_reachable_nodes_map**: Pre-computed forward reachability for all nodes. + Maps each node to all nodes reachable from it (with distances). + - **root**: Root region containing all graph nodes, used as search space. + + **Key Algorithms:** + - **Divergence Detection**: Identifies nodes whose outputs branch to multiple consumers + - **Convergence Detection**: Finds nodes where multiple branches rejoin + - **Boundary Computation**: Determines input/output tensors for regions + - **Reachability Analysis**: Computes forward-reachable nodes with distances + + **Design Pattern:** + This is a base class meant to be subclassed. 
Subclasses implement specific + region formation strategies (e.g., bottom-up partitioning, top-down refinement) + while reusing the common analysis utilities provided here. + + **Performance:** + Pre-computation in __init__ scales with graph size: + - tensor_users_map: O(E) where E = number of edges + - forward_reachable_nodes_map: O(N * (N + E)) where N = number of nodes + + For large graphs, initialization may take significant time but enables + efficient queries during region formation. + + Attributes: + graph: The ONNX computation graph (onnx_graphsurgeon.Graph) + root: Root region containing all nodes in the graph + tensor_users_map: Mapping from tensor names to consuming node indices + forward_reachable_nodes_map: Pre-computed forward reachability for all nodes + + Example: + >>> # Typically used as a base class + >>> class MyRegionSearch(RegionSearchBase): + ... def find_regions(self): + ... # Use inherited utilities like _is_node_divergent() + ... pass + """ + + def __init__( + self, graph: gs.Graph, root: Region | None = None, max_steps: int = DEFAULT_MAX_STEPS + ): + """Initialize the base region search with graph analysis. + + Performs pre-computation of essential data structures for efficient + region analysis: + 1. Creates or validates root region containing all nodes + 2. Builds tensor-to-users mapping for divergence detection + 3. Pre-computes forward reachability for convergence detection + + Args: + graph: The ONNX graph to analyze (onnx_graphsurgeon.Graph) + root: Optional root region. If None, creates one containing all nodes. + max_steps: Maximum distance for forward reachability pre-computation. + Limits memory usage and computation time for large graphs. + + Note: + Initialization time scales with graph complexity. For graphs with + thousands of nodes, this may take several seconds. + """ + self.graph = graph + if root is None: + root = self._build_root_region() + self.root = root + self.tensor_users_map = get_tensor_consumer_node_indices(self.graph) + self.forward_reachable_nodes_map = self._build_forward_reachable_nodes_map( + max_steps=max_steps + ) + + def _build_root_region(self) -> Region: + """Create a root region containing all nodes in the graph. + + The root region serves as the universal search space for region + formation algorithms. It represents the entire computation graph + as a single region before any partitioning. + + Returns: + Region of type ROOT containing all graph nodes. + """ + root = Region(region_id=0, level=0, region_type=RegionType.ROOT) + for node_idx in range(len(self.graph.nodes)): + root.add_node(node_idx) + for tensor_name in root.get_inputs(): + root.add_input(tensor_name) + for tensor_name in root.get_outputs(): + root.add_output(tensor_name) + return root + + def _is_tensor_divergent(self, tensor_name: str) -> bool: + """Check if a tensor is consumed by multiple nodes (divergent). + + A divergent tensor indicates branching in the computation graph, + where one operation's output feeds into multiple downstream operations. + + Args: + tensor_name: Name of the tensor to check + + Returns: + True if tensor has more than one consumer, False otherwise + """ + return len(self.tensor_users_map.get(tensor_name, [])) > 1 + + def _is_node_divergent(self, node_idx: int) -> bool: + """Check if a node has outputs that branch to multiple consumers. + + A divergent node is one that produces outputs consumed by multiple + downstream nodes, creating branches in the computation graph. These + nodes are important boundaries for region formation. 
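+
+        For instance, in this hypothetical fragment, ``Conv`` is divergent
+        because its output tensor feeds two consumers::
+
+            Conv ──┬──> Relu ──> ...
+                   └──> Add  ──> ...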
+ + **Significance:** + - Divergent nodes often represent natural region boundaries + - They indicate where computation splits into parallel paths + - Useful for identifying opportunities for parallel optimization + + Args: + node_idx: Index of the node to check + + Returns: + True if the node has at least one output consumed by multiple nodes, + False otherwise or if node is not in root region. + + Example: + >>> # Node 10 outputs tensor "X" consumed by nodes 11 and 12 + >>> _is_node_divergent(10) # Returns True + """ + if node_idx not in self.root.get_nodes(): + logger.debug(f"Node {node_idx} not in root region") + return False + + node = self.graph.nodes[node_idx] + divergent_outputs = [ + out.name for out in node.outputs if self._is_tensor_divergent(out.name) + ] + is_divergent = len(divergent_outputs) > 0 + + if is_divergent: + logger.debug( + f"Divergent node {node_idx} ({node.op}): {len(divergent_outputs)} branches" + ) + + return is_divergent + + def _compute_forward_reachable_nodes( + self, start_node_idx: int, max_steps: int + ) -> dict[int, int]: + """Compute all nodes reachable forward from a starting node with distances. + + Uses breadth-first search (BFS) to find all nodes reachable by following + forward edges (data flow direction) from the start node, up to a maximum + distance. Records the shortest-path distance to each reachable node. + + **Algorithm:** + 1. Initialize with start node at distance 0 + 2. For each node in queue: + - If at max distance, skip + - For each output tensor: + - For each consumer of that tensor: + - If not yet visited, add to queue with distance+1 + + **Use Cases:** + - Convergence detection: Find where branches rejoin + - Region size estimation: Count nodes in forward cone + - Dependency analysis: Understand downstream impact + + Args: + start_node_idx: Index of node to start search from + max_steps: Maximum forward distance to explore + + Returns: + Dictionary mapping reachable node indices to their distances from start. + Includes start_node_idx mapped to distance 0. + + Example: + >>> # Find all nodes within 5 steps forward of node 10 + >>> reachable = _compute_forward_reachable_nodes(10, 5) + >>> reachable[10] # 0 (start node) + >>> reachable[15] # 3 (if node 15 is 3 steps away) + """ + reachable: dict[int, int] = {start_node_idx: 0} + queue: deque[tuple[int, int]] = deque([(start_node_idx, 0)]) + while queue: + current_node_idx, distance = queue.popleft() + if distance >= max_steps: + continue + current_node = self.graph.nodes[current_node_idx] + for output in current_node.outputs: + if output.name not in self.tensor_users_map: + continue + for next_node_idx in self.tensor_users_map[output.name]: + if next_node_idx not in reachable: + reachable[next_node_idx] = distance + 1 + queue.append((next_node_idx, distance + 1)) + return reachable + + def _build_forward_reachable_nodes_map(self, max_steps: int) -> dict[int, dict[int, int]]: + """Pre-compute forward reachability for all nodes in the graph. + + This is a key optimization that enables efficient convergence detection. + By pre-computing forward reachability once, we can quickly answer queries + like "Can node A reach node B?" and "What is the distance from A to B?" + + **Complexity:** + - Time: O(N * (N + E)) where N = nodes, E = edges + - Space: O(N²) in worst case for dense graphs + + **Trade-off:** + Pre-computation takes time upfront but dramatically speeds up convergence + detection, which would otherwise require repeated BFS traversals. 
+ + Args: + max_steps: Maximum forward distance to pre-compute for each node. + Limits both time and space complexity. + + Returns: + Nested dictionary where outer key is start node index, inner key is + reachable node index, and value is shortest-path distance. + + Example: + >>> map = _build_forward_reachable_nodes_map(10) + >>> map[5][8] # Distance from node 5 to node 8 + 3 + >>> 12 in map[5] # Can node 5 reach node 12? + True + """ + logger.debug(f"Building forward reachability map (max_steps={max_steps})...") + forward_reachable_nodes_map: dict[int, dict[int, int]] = {} + for node_idx in self.root.get_nodes(): + forward_reachable_nodes_map[node_idx] = self._compute_forward_reachable_nodes( + node_idx, max_steps + ) + + total_reachable = sum(len(reachable) for reachable in forward_reachable_nodes_map.values()) + avg_reachable = total_reachable / len(self.root.get_nodes()) if self.root.get_nodes() else 0 + logger.debug(f"Reachability map complete: avg {avg_reachable:.1f} reachable nodes per node") + return forward_reachable_nodes_map + + def _find_converge_nodes(self, node_idx: int) -> tuple[int | None, set[int]]: + """Find convergence point and intermediate nodes for a divergent node. + + Given a divergent node (where computation branches), this method finds: + 1. The convergence node: Where the branches rejoin + 2. All nodes between divergence and convergence + + **Algorithm:** + 1. Identify all branches from the divergent node + 2. Find nodes reachable from all branches (common nodes) + 3. Select nearest common node that forms a valid region + 4. Compute all nodes between divergence and convergence + + **Convergence Criteria:** + A valid convergence node must: + - Be reachable from all branches + - Form a contiguous region (no nodes escape the region) + - Be the nearest such node (minimize region size) + + **Region Validity:** + A region is valid if all nodes within it either stay in the region + or directly reach the convergence point. No node should reach outside + the region before reaching the convergence point. + + Args: + node_idx: Index of the divergent node to find convergence for + + Returns: + Tuple of (converge_node_idx, visited_nodes): + - converge_node_idx: Index of convergence node, or None if not found + - visited_nodes: Set of node indices between divergence and convergence + + Example: + >>> # Node 10 branches to 11 and 12, which rejoin at node 15 + >>> converge_idx, visited = _find_converge_nodes(10) + >>> converge_idx # 15 + >>> visited # {10, 11, 12, 13, 14} (all nodes in between) + """ + node = self.graph.nodes[node_idx] + logger.debug(f"Finding convergence for node {node_idx} ({node.op})") + + branches: list[int] = [] + for output in node.outputs: + if output.name in self.tensor_users_map: + branches.extend(self.tensor_users_map[output.name]) + + seen: set[int] = set() + unique_branches: list[int] = [] + for branch_idx in branches: + if branch_idx not in seen: + seen.add(branch_idx) + unique_branches.append(branch_idx) + branches = unique_branches + + logger.debug(f" {len(branches)} unique branches found") + + # Need at least 2 branches for convergence to be meaningful + if len(branches) <= 1: + logger.debug(" Insufficient branches for convergence") + return None, set() + + # ===================================================================== + # STEP 1: Find Common Reachable Nodes (Potential Convergence Points) + # ===================================================================== + # A valid convergence node must be reachable from ALL branches. 
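+        # For example (hypothetical indices): if node 10 branches to nodes 11
+        # and 12, and both branches can reach node 15, then node 15 is a
+        # convergence candidate.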
+ # Use pre-computed forward reachability for efficiency. + + # Collect forward-reachable nodes for each branch + branch_reachable: list[dict[int, int]] = [] + for branch_idx in branches: + reachable = self.forward_reachable_nodes_map.get(branch_idx, {}) + branch_reachable.append(reachable) + + if not branch_reachable: + logger.debug(" No reachable nodes from branches") + return None, set() + + # Find intersection: nodes reachable from ALL branches + # These are the only candidates for convergence points + common_nodes = set(branch_reachable[0].keys()) + for reachable in branch_reachable[1:]: + common_nodes.intersection_update(reachable.keys()) + + logger.debug(f" {len(common_nodes)} common nodes found") + + # Remove the divergent node itself (not a convergence point) + common_nodes.discard(node_idx) + + if not common_nodes: + logger.debug(" No valid convergence candidates") + return None, set() + + # ===================================================================== + # STEP 2: Select Best Convergence Node with Region Validity Check + # ===================================================================== + # Not all common nodes make good convergence points. We need to ensure + # the region formed is "valid" - i.e., contiguous with no escaping edges. + # + # Region validity criterion: + # For every node R in the region (between divergence and candidate): + # For every node T reachable from R: + # If T is outside the region: + # T must be at least as far from R as the candidate is + # (i.e., R doesn't "escape" before reaching candidate) + + converge_node_idx: int | None = None + min_max_distance = float("inf") + + # Get all nodes reachable from the divergent node + reachable_from_start = self.forward_reachable_nodes_map.get(node_idx, {}) + + # Evaluate each candidate convergence point + for candidate_idx in common_nodes: + # --------------------------------------------------------------- + # Define the potential region: nodes between start and candidate + # --------------------------------------------------------------- + # Region = nodes reachable from start BUT NOT reachable from candidate + # (candidate acts as the boundary) + region_nodes: set[int] = set() + region_nodes.update(set(reachable_from_start.keys())) + reachable_from_candidate = self.forward_reachable_nodes_map.get(candidate_idx, {}) + # Remove nodes beyond the candidate (not in our region) + region_nodes.difference_update(set(reachable_from_candidate.keys())) + + # --------------------------------------------------------------- + # Validate region: Check for "escaping" edges + # --------------------------------------------------------------- + # A region is invalid if any node inside can reach a node outside + # BEFORE reaching the convergence point. This would mean the region + # has edges that "leak out" and isn't properly bounded. + broken_region = False + + # Check each node in the proposed region + for rnode_index in region_nodes: + # Get all nodes reachable from this region node + reachable_from_rnode = self.forward_reachable_nodes_map.get(rnode_index, {}) + + # Distance from this node to the candidate (convergence) + rnode_to_candidate_distance = reachable_from_rnode.get(candidate_idx, float("inf")) + + # Check all nodes reachable from this region node + for test_node_idx in reachable_from_rnode: + # Skip nodes that are inside the region (they're fine) + if test_node_idx in region_nodes: + continue + + # test_node is OUTSIDE the region. 
Check if it's "escaping" + # An escaping edge: region_node reaches test_node BEFORE candidate + rnode_to_test_distance = reachable_from_rnode.get(test_node_idx, float("inf")) + + # If either distance is infinite, region is broken + # (indicates disconnected components or unreachable convergence) + if rnode_to_test_distance == float( + "inf" + ) or rnode_to_candidate_distance == float("inf"): + broken_region = True + break + + # If test_node is closer than candidate, we have an escape! + # This means computation flows OUT of region before converging + if rnode_to_test_distance < rnode_to_candidate_distance: + broken_region = True + break + + if broken_region: + break + + # Skip this candidate if region is invalid + if broken_region: + continue + + # --------------------------------------------------------------- + # Valid candidate! Check if it's the nearest one + # --------------------------------------------------------------- + # We want the closest convergence point to minimize region size + # "Distance" = maximum distance from any branch to convergence + max_distance = max(reachable[candidate_idx] for reachable in branch_reachable) + + if max_distance < min_max_distance: + min_max_distance = max_distance + converge_node_idx = candidate_idx + + # If no valid convergence found, this divergence has no convergence + if converge_node_idx is None: + logger.debug(" No valid convergence found") + return None, set() + + converge_node = self.graph.nodes[converge_node_idx] + logger.debug( + f" Convergence at node {converge_node_idx} ({converge_node.op}), distance {min_max_distance}" + ) + + # ===================================================================== + # STEP 3: Compute All Nodes Between Divergence and Convergence + # ===================================================================== + # Now that we have a valid convergence point, we need to identify ALL + # nodes that should be included in the convergence region. + # + # A node is "between" divergence and convergence if: + # 1. It's reachable from the divergence node (on some path from divergence) + # 2. The convergence node is reachable from it (on some path to convergence) + # 3. It's not the convergence node itself (convergence is the boundary) + # + # This captures all the "interior" nodes of the funnel/diamond pattern, + # including all branches and intermediate computations. + + visited_nodes: set[int] = set() + + # Check each node reachable from the divergent node + for candidate_idx in reachable_from_start: + # Skip the convergence node itself (it's the boundary, not interior) + if candidate_idx == converge_node_idx: + continue + + # Check if this node can reach the convergence node + # If yes, it's on a path from divergence to convergence + reachable_from_candidate = self.forward_reachable_nodes_map.get(candidate_idx, {}) + if converge_node_idx in reachable_from_candidate: + # This node is between divergence and convergence! + visited_nodes.add(candidate_idx) + + logger.debug(f" {len(visited_nodes)} nodes between divergence and convergence") + return converge_node_idx, visited_nodes + + def _max_distance_to_nodes(self, src_idx: int, dst_indices: set[int]) -> int: + """Compute maximum distance from a source node to a set of destination nodes. + + Uses pre-computed forward reachability to efficiently find the maximum + shortest-path distance from src_idx to any node in dst_indices. 
+
+        **Use Cases:**
+        - Determine if a convergence region is within acceptable size limits
+        - Measure the "spread" of nodes in a potential region
+        - Validate region compactness constraints
+
+        Args:
+            src_idx: Source node index
+            dst_indices: Set of destination node indices
+
+        Returns:
+            Maximum distance from src to any node in dst_indices.
+            Returns 0 if dst_indices is empty or no nodes are reachable.
+
+        Example:
+            >>> # Check if all nodes are within 10 steps
+            >>> max_dist = _max_distance_to_nodes(start_node, candidate_nodes)
+            >>> if max_dist <= 10:
+            ...     # Region is compact enough
+        """
+        max_distance = 0
+        # Look up the pre-computed reachability for the source once
+        reachable = self.forward_reachable_nodes_map.get(src_idx, {})
+        for dst_idx in dst_indices:
+            if dst_idx in reachable:
+                max_distance = max(max_distance, reachable[dst_idx])
+
+        logger.debug(
+            f"Max distance from node {src_idx}: {max_distance} steps to {len(dst_indices)} nodes"
+        )
+        return max_distance
+
+    def compute_region_boundaries(self, region: Region, include_constant: bool = False) -> None:
+        """Compute input and output tensor boundaries for a region.
+
+        **Algorithm:**
+        1. Collect all tensors consumed by region nodes (potential inputs)
+        2. Collect all tensors produced by region nodes (potential outputs)
+        3. Input = consumed tensors NOT produced by region nodes
+        4. Output = produced tensors consumed by nodes OUTSIDE the region
+
+        This accurately captures the data flow boundaries of the region.
+
+        Args:
+            region: The region to compute boundaries for
+            include_constant: If True, treat constant tensors as region inputs
+                as well (default: False)
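+
+        Example:
+            A sketch, assuming ``search`` is a RegionSearchBase subclass
+            instance and ``region`` lies within its graph (tensor names are
+            hypothetical):
+
+            >>> search.compute_region_boundaries(region)
+            >>> region.inputs, region.outputs  # doctest: +SKIP
+            (['input_0'], ['output_0'])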
+        """
+        node_indices = region.get_all_nodes_recursive()
+        all_inputs: set[str] = set()
+        all_outputs: set[str] = set()
+        internal_tensors: set[str] = set()
+
+        # First pass: collect all inputs and outputs
+        for node_idx in node_indices:
+            if node_idx >= len(self.graph.nodes):
+                continue
+            node = self.graph.nodes[node_idx]
+            # Collect input tensors
+            for input_tensor in node.inputs:
+                if isinstance(input_tensor, gs.Constant) and not include_constant:
+                    continue
+                all_inputs.add(input_tensor.name)
+            # Collect output tensors
+            for output_tensor in node.outputs:
+                all_outputs.add(output_tensor.name)
+                internal_tensors.add(output_tensor.name)
+
+        # Region inputs = consumed tensors not produced internally
+        region_inputs = all_inputs - internal_tensors
+
+        # Region outputs = produced tensors consumed externally
+        region_outputs: set[str] = set()
+        for node_idx in node_indices:
+            if node_idx >= len(self.graph.nodes):
+                continue
+            node = self.graph.nodes[node_idx]
+            for output_tensor in node.outputs:
+                tensor_name = output_tensor.name
+                if tensor_name not in self.tensor_users_map:
+                    region_outputs.add(tensor_name)
+                    continue
+                # Check if any consumer is outside the region
+                has_external_consumer = False
+                # Get consumer nodes from tensor_users_map
+                consumer_indices = self.tensor_users_map[tensor_name]
+                for consumer_idx in consumer_indices:
+                    if consumer_idx not in node_indices:
+                        # Consumer is outside the region
+                        has_external_consumer = True
+                        break
+                if has_external_consumer:
+                    region_outputs.add(tensor_name)
+                # Also check if this is a graph output
+                if output_tensor in self.graph.outputs:
+                    region_outputs.add(tensor_name)
+
+        # Add to region
+        region.inputs = sorted(region_inputs)
+        region.outputs = sorted(region_outputs)
+
+        logger.debug(
+            f"Computed boundaries: {len(region_inputs)} inputs, {len(region_outputs)} outputs"
+        )
+
+    def print_tree(
+        self,
+        region: Region | None = None,
+        indent: int = 0,
+        max_nodes_to_show: int = DEFAULT_MAX_NODES_TO_SHOW,
+        file=None,
+    ) -> None:
+        """Print hierarchical region tree in human-readable text format.
+
+        Recursively prints the region hierarchy with indentation showing depth.
+        For each region, displays:
+        - ID, level, and type (LEAF/COMPOSITE/ROOT)
+        - Node counts (direct and recursive)
+        - I/O tensor counts
+        - Sample of nodes in the region (up to max_nodes_to_show)
+        - Child regions (recursively)
+
+        Args:
+            region: Region to print (None defaults to root)
+            indent: Current indentation level (0 = root)
+            max_nodes_to_show: Maximum nodes to display per region
+                (default: DEFAULT_MAX_NODES_TO_SHOW, i.e. 20)
+            file: Output file object (None defaults to stdout)
+
+        Example:
+            >>> builder.print_tree()
+            ├─ Region 0 (Level 0, Type: ROOT)
+            │ ├─ Direct nodes: 0
+            │ └─ Children: 2
+            │ ├─ Region 1 (Level 1, Type: COMPOSITE)
+            ...
+        """
+        region = region or self.root
+        if region is None:
+            return
+
+        if file is None:
+            file = sys.stdout
+
+        prefix = " " * indent
+
+        # Print region header
+        region_type = region.get_type().value
+        print(
+            f"{prefix}├─ Region {region.get_id()} (Level {region.get_level()}, Type: {region_type})",
+            file=file,
+        )
+
+        # Print region size info
+        direct_nodes = region.get_nodes()
+        total_nodes = region.get_all_nodes_recursive()
+        num_children = len(region.get_children())
+
+        print(f"{prefix}│ ├─ Direct nodes: {len(direct_nodes)}", file=file)
+        print(f"{prefix}│ ├─ Total nodes (recursive): {len(total_nodes)}", file=file)
+        print(f"{prefix}│ ├─ Children: {num_children}", file=file)
+
+        # Print region I/O info
+        inputs = region.get_inputs()
+        outputs = region.get_outputs()
+        print(f"{prefix}│ ├─ Inputs: {len(inputs)} tensors", file=file)
+        if inputs:
+            for tensor_name in list(inputs)[:max_nodes_to_show]:
+                print(f"{prefix}│ │ - {tensor_name}", file=file)
+            if len(inputs) > max_nodes_to_show:
+                print(f"{prefix}│ │ ... and {len(inputs) - max_nodes_to_show} more", file=file)
+        print(f"{prefix}│ └─ Outputs: {len(outputs)} tensors", file=file)
+        if outputs:
+            for tensor_name in list(outputs)[:max_nodes_to_show]:
+                print(f"{prefix}│    - {tensor_name}", file=file)
+            if len(outputs) > max_nodes_to_show:
+                print(f"{prefix}│    ... and {len(outputs) - max_nodes_to_show} more", file=file)
+
+        # Print direct nodes in this region (if any)
+        if direct_nodes:
+            print(f"{prefix}│", file=file)
+            print(f"{prefix}│ Nodes in this region:", file=file)
+            nodes_list = sorted(direct_nodes)[:max_nodes_to_show]
+            for node_idx in nodes_list:
+                if node_idx < len(self.graph.nodes):
+                    node = self.graph.nodes[node_idx]
+                    print(
+                        f"{prefix}│   - Node {node_idx}: {node.op} (name: {node.name})", file=file
+                    )
+
+            if len(direct_nodes) > max_nodes_to_show:
+                print(
+                    f"{prefix}│   ... and {len(direct_nodes) - max_nodes_to_show} more nodes",
+                    file=file,
+                )
+
+        # Print children (recursively)
+        children = region.get_children()
+        if children:
+            print(f"{prefix}│", file=file)
+            print(f"{prefix}│ Child regions:", file=file)
+            for child in children:
+                print(f"{prefix}│", file=file)
+                self.print_tree(child, indent + 1, max_nodes_to_show, file)
+
+
+class RegionPartitioner(RegionSearchBase):
+    """Bottom-up graph partitioner that creates initial regions based on divergence patterns.
+
+    This class implements Phase 1 of the combined region search strategy. It performs
+    a systematic traversal of the computation graph from inputs to outputs, identifying
+    natural boundaries for region formation based on computation flow patterns.
+
+    **Core Strategy:**
+    Partitions the graph by analyzing three types of computational patterns:
+
+    1. 
**Divergent Nodes with Convergence:** + - Nodes whose outputs branch to multiple paths (divergence) + - Paths that eventually rejoin at a common node (convergence) + - Creates a single region encompassing divergence + branches + convergence + - Example: A → (B,C) → D creates region containing {A, B, C, D} + + 2. **Divergent Nodes without Convergence:** + - Nodes whose outputs branch but never rejoin + - Creates a single-node "orphan" region for the divergent node + - Example: A → (B,C) with no convergence creates region {A} + + 3. **Linear Sequences:** + - Chains of non-divergent nodes (simple sequential computation) + - Groups entire sequence into one region + - Example: A → B → C → D creates region {A, B, C, D} + + **Algorithm Overview:** + ``` + For each node in graph order: + If already visited: skip + If divergent: + Find convergence point + If convergence exists within threshold: + Create region with all nodes between divergence and convergence + Else: + Create single-node region (orphan) + Else (non-divergent): + Build sequence: follow chain until hitting divergent node + Create region containing entire sequence + ``` + + **Key Features:** + - **Complete Coverage:** Every node is assigned to exactly one region + - **Convergence Detection:** Uses pre-computed reachability for efficiency + - **Distance Threshold:** Limits region size to DEFAULT_MAX_STEPS + - **Sequential Processing:** Respects data flow order for natural groupings + + **Region Types Created:** + All regions created by this class are LEAF regions (level 0). Higher-level + structure is created later by TopDownRegionBuilder. + + **State Management:** + - **visited_nodes:** Tracks which nodes have been assigned to regions + - **current_region:** Region being built (commit when complete) + - **regions:** List of completed regions + - **current_region_id:** Counter for unique region IDs + + **Output:** + A list of LEAF regions that partition the entire graph. These regions + serve as input to Phase 2 (TopDownRegionBuilder) for refinement. + + **Example:** + ```python + partitioner = RegionPartitioner(graph) + initial_regions = partitioner.partition_graph() + + # Analyze results + print(f"Created {len(initial_regions)} regions") + print(f"Covered {len(partitioner.visited_nodes)} / {len(graph.nodes)} nodes") + + # Typical output for a ResNet layer: + # - Conv node → orphan region (diverges to BN and skip path) + # - BN → ReLU sequence → sequential region + # - Add (convergence) → orphan or part of next sequence + ``` + + **Performance:** + - Time: O(N) where N = number of nodes (each visited once) + - Space: O(N) for visited_nodes set and region storage + + Attributes: + regions: List of completed LEAF regions + current_region: Region currently being built (None if between regions) + current_region_id: Counter for assigning unique region IDs + visited_nodes: Set of node indices already assigned to regions + + See Also: + TopDownRegionBuilder: Phase 2 refinement of partitioner output + CombinedRegionSearch: Orchestrates both phases + """ + + def __init__(self, graph: gs.Graph): + """Initialize the partitioner with a computation graph. + + Sets up necessary data structures and inherits graph analysis utilities + from RegionSearchBase (tensor users map, reachability, etc.). 
+
+        Args:
+            graph: The ONNX graph to partition (onnx_graphsurgeon.Graph)
+        """
+        super().__init__(graph, root=None)
+        self.regions: list[Region] = []
+        self.current_region: Region | None = None
+        self.current_region_id: int = 0
+        self.visited_nodes: set[int] = set()
+
+    def _append_node_to_region(self, node_idx: int):
+        """Add a node to the current region, creating a new region if needed.
+
+        This is the primary method for building regions incrementally. If no
+        region is currently active, creates a new LEAF region. Then adds the
+        specified node to that region.
+
+        **Usage Pattern:**
+        Typically called multiple times to build up a region, then followed
+        by _commit_region() to finalize and store the completed region.
+
+        Args:
+            node_idx: Index of node to add to current region
+
+        Side Effects:
+            - Creates new region if current_region is None
+            - Increments current_region_id when creating new region
+            - Adds node to current_region
+        """
+        node = self.graph.nodes[node_idx]
+        if self.current_region is None:
+            self.current_region = Region(
+                region_id=self.current_region_id, level=0, region_type=RegionType.LEAF
+            )
+            logger.debug(f"Started region {self.current_region_id}")
+            self.current_region_id += 1
+
+        self.current_region.add_node(node_idx)
+        logger.debug(
+            f"  Added node {node_idx} ({node.op}), region size: {self.current_region.get_size()}"
+        )
+
+    def _commit_region(self):
+        """Finalize and store the current region being built.
+
+        Completes region construction by:
+        1. Computing input/output tensor boundaries
+        2. Adding region to the completed regions list
+        3. Resetting current_region to None for next region
+
+        **Boundary Computation:**
+        Determines which tensors flow into and out of the region based on
+        which nodes produce/consume them. This is essential for understanding
+        region dependencies.
+
+        **Post-Conditions:**
+        - current_region is added to regions list
+        - current_region is reset to None
+        - Region has computed input/output tensor lists
+
+        Side Effects:
+            - Appends current_region to self.regions
+            - Sets current_region to None
+            - Logs region commit with size info
+        """
+        if self.current_region is not None:
+            region_size = self.current_region.get_size()
+            region_id = self.current_region.id
+
+            # Compute input/output tensor boundaries
+            self.compute_region_boundaries(self.current_region)
+
+            self.regions.append(self.current_region)
+            logger.debug(
+                f"Committed region {region_id}: {region_size} nodes (total: {len(self.regions)})"
+            )
+            self.current_region = None
+        else:
+            logger.debug("No region to commit")
+
+    def _build_sequence_from_node(self, node_idx: int, max_nodes: int = -1):
+        """Build a region from a linear sequence of non-divergent nodes.
+
+        Starting from a non-divergent node, follows the forward chain of nodes,
+        adding each node on the chain to the current region. A divergent node is
+        appended as the final node of its chain; its successors are handled
+        separately. Already-visited successors and the end of the graph also
+        terminate the traversal.
+
+        **Algorithm:**
+        ```
+        queue = [start_node]
+        while queue not empty:
+            node = dequeue()
+            if node is divergent:
+                add node to region, do not follow successors
+            else:
+                add node to region
+                enqueue successors not yet assigned to a region
+        (the caller commits the region)
+        ```
+
+        **Example:**
+        For graph: Conv → BN → ReLU → MaxPool (no branching)
+        Creates one region containing all four nodes.
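+
+        For graph: Conv → BN → Split (Split branches to two paths)
+        Creates one region {Conv, BN, Split}; the divergent Split terminates
+        the chain and its successors are processed separately.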
+
+        **Stopping Conditions:**
+        - Divergent node encountered (appended as the chain's final node)
+        - All successors already visited
+        - No more forward connections
+
+        Args:
+            node_idx: Index of starting node (must be non-divergent)
+            max_nodes: Stop after adding this many nodes (-1 = unlimited)
+
+        Side Effects:
+            - Adds nodes to current_region via _append_node_to_region
+            - Marks nodes as visited
+            - Leaves the region uncommitted for the caller to finalize
+
+        Note:
+            Does not commit the region itself, even if only one node was added;
+            committing via _commit_region() is the caller's responsibility.
+        """
+        start_node = self.graph.nodes[node_idx]
+        logger.debug(f"Building sequence from node {node_idx} ({start_node.op})")
+
+        queue: deque[int] = deque([node_idx])
+        nodes_added = 0
+
+        while len(queue) > 0:
+            current_node_idx = queue.popleft()
+            current_node = self.graph.nodes[current_node_idx]
+
+            if not self._is_node_divergent(current_node_idx):
+                self._append_node_to_region(current_node_idx)
+                self.visited_nodes.add(current_node_idx)
+                nodes_added += 1
+
+                # Find successors
+                successor_count = 0
+                for output_tensor in current_node.outputs:
+                    if output_tensor.name in self.tensor_users_map:
+                        successors = self.tensor_users_map[output_tensor.name]
+                        successor_count += len(successors)
+                        # Only follow successors not yet assigned to a region
+                        queue.extend(s for s in successors if s not in self.visited_nodes)
+            else:
+                # Divergent node terminates the chain; mark it visited so it is
+                # not assigned to a second region later
+                self._append_node_to_region(current_node_idx)
+                self.visited_nodes.add(current_node_idx)
+                nodes_added += 1
+                logger.debug(f"  Stopped at divergent node {current_node_idx} ({current_node.op})")
+
+            if max_nodes > 0 and nodes_added >= max_nodes:
+                logger.debug("  Max nodes reached")
+                break
+
+        logger.debug(f"Sequence complete: {nodes_added} nodes")
+
+    def _build_small_converged_region(
+        self, start_node_idx: int, converge_node_idx: int, visited_nodes: set[int]
+    ):
+        r"""Create a region encompassing divergence, branches, and convergence.
+
+        Builds a single region containing:
+        - The divergent node (where branches split)
+        - All nodes in the branches
+        - The convergence node (where branches rejoin), unless it is itself divergent
+
+        This creates a "diamond" or "funnel" shaped region that captures
+        parallel computation paths and their merge point.
+
+        **Structure:**
+        ```
+            start (divergent)
+             /   \
+        path1     path2   (visited_nodes)
+             \\   /
+          convergence
+        ```
+
+        **Example:**
+        For ResNet skip connection:
+        - start_node: Output of previous layer (branches)
+        - visited_nodes: {Conv, BN, ReLU, Conv, BN} (main path)
+        - converge_node: Add operation (merges with skip)
+
+        Args:
+            start_node_idx: The divergent node where branches begin
+            converge_node_idx: Where branches rejoin; appended to the region (and
+                extended by a short trailing sequence) when it is not itself divergent
+            visited_nodes: All nodes between divergence and convergence,
+                including the divergent start node
+
+        Side Effects:
+            - Adds all nodes to current region
+            - Marks all nodes as visited
+            - Leaves the completed region for the caller to commit
+        """
+        # visited_nodes includes the divergent start node; keep it in the region
+        # so that every node ends up in exactly one region
+        for node_idx in sorted(visited_nodes):
+            self._append_node_to_region(node_idx)
+            self.visited_nodes.add(node_idx)
+        if not self._is_node_divergent(converge_node_idx):
+            self._append_node_to_region(converge_node_idx)
+            self.visited_nodes.add(converge_node_idx)
+            self._build_sequence_from_node(converge_node_idx, max_nodes=3)
+
+    def _build_region_from_node(self, node_idx: int):
+        """Process a single node and create appropriate region(s) based on its pattern.
+
+        This is the core dispatch method that determines how to handle each node
+        based on whether it's divergent (branches) or sequential. Implements the
+        three pattern recognition strategies described in the class documentation.
+ + **Decision Logic:** + ``` + If node already visited: + Skip (already in a region) + Else if node is divergent: + Try to find convergence point + If convergence found within distance threshold: + Create convergence region (divergence + branches + convergence) + Else: + Create orphan region (just the divergent node) + Else (non-divergent): + Build sequence region (follow chain until divergence) + ``` + + **Pattern 1: Divergent with Convergence (Ideal Case)** + Creates a complete "funnel" region capturing parallel branches: + - Example: ResNet skip connection (Conv branch + identity → Add) + - Condition: converge_node found AND distance < DEFAULT_MAX_STEPS + - Result: One region containing all nodes between divergence and convergence + + **Pattern 2: Divergent without Convergence (Boundary Case)** + Creates a single-node "orphan" region: + - Example: Final layer that branches to multiple outputs + - Condition: No convergence found OR convergence too far away + - Result: Region containing only the divergent node + + **Pattern 3: Sequential Chain (Common Case)** + Creates a region containing linear sequence: + - Example: Conv → BN → ReLU → MaxPool + - Condition: Node is not divergent + - Result: Region containing the full non-divergent chain + + Args: + node_idx: Index of node to process + + Side Effects: + - Marks processed nodes as visited + - Creates and commits region(s) via helper methods + - May recursively process successor nodes (in sequence building) + + Note: + This method is idempotent - calling it multiple times on the same + node has no effect after the first call (due to visited check). + """ + node = self.graph.nodes[node_idx] + + # Skip nodes already assigned to regions + if node_idx in self.visited_nodes: + logger.debug(f"Skipping node {node_idx} ({node.op}): already visited") + return + + logger.debug(f"Processing node {node_idx} ({node.op})") + + # Pattern 1 & 2: Handle divergent nodes + if self._is_node_divergent(node_idx): + logger.debug(" Divergent node, searching for convergence") + + # Attempt to find where branches rejoin + converge_node_idx, visited_nodes = self._find_converge_nodes(node_idx) + + # Check if convergence creates a reasonable-sized region + max_distance = self._max_distance_to_nodes(node_idx, visited_nodes) + + # Pattern 1: Convergence found and region size is acceptable + if converge_node_idx is not None and max_distance < DEFAULT_MAX_STEPS: + converge_node = self.graph.nodes[converge_node_idx] + logger.debug( + f" Creating converged region: {len(visited_nodes)} nodes, " + f"convergence at {converge_node_idx} ({converge_node.op}), distance {max_distance}" + ) + # Create region containing: divergence + all branches + convergence + self._build_small_converged_region(node_idx, converge_node_idx, visited_nodes) + self._commit_region() + # Pattern 2: No convergence or region would be too large + else: + logger.debug(" Creating orphan region for divergent node") + # Create single-node region for this divergent node + # Its successors will be processed separately + self._append_node_to_region(node_idx) + self.visited_nodes.add(node_idx) + self._commit_region() + # Pattern 3: Handle non-divergent (sequential) nodes + else: + logger.debug(" Non-divergent node, building sequence") + # Build region by following the linear chain forward + self._build_sequence_from_node(node_idx) + self._commit_region() + + def partition_graph(self): + """Partition the entire graph into non-overlapping LEAF regions. 
+ + This is the main entry point for bottom-up graph partitioning. Performs + a single pass over all nodes in graph order, creating regions based on + divergence/convergence patterns and sequential chains. + + **Algorithm:** + ``` + For each node in graph (in index order): + If node not yet visited: + Analyze node type (divergent vs sequential) + Create appropriate region(s) for node and its neighborhood + Mark processed nodes as visited + + Result: Complete partitioning where every node belongs to exactly one region + ``` + + **Processing Order:** + Nodes are processed in index order (typically matches graph construction + order / topological-ish order). This tends to group naturally related + operations together. + + **Completeness Guarantee:** + Every node in the graph will be assigned to exactly one region. The + visited_nodes set ensures no node is processed twice, and the loop over + all indices ensures no node is skipped. + + **Region Types Created:** + - Convergence regions: Divergent node + branches + convergence + - Orphan regions: Single divergent node with no close convergence + - Sequence regions: Linear chains of non-divergent nodes + + **Output Quality:** + - Total regions: Typically 10-30% of total nodes (varies by graph) + - Region sizes: Mix of small (1-3 nodes) and medium (5-15 nodes) + - Coverage: 100% of graph nodes + + Returns: + List of LEAF regions that partition the entire graph. + Each node appears in exactly one region. + Regions are stored in self.regions and also returned. + + Side Effects: + - Populates self.regions with created regions + - Populates self.visited_nodes with all node indices + - Logs progress and statistics + + Example: + >>> partitioner = RegionPartitioner(graph) + >>> regions = partitioner.partition_graph() + >>> # Verify complete coverage + >>> all_nodes = set() + >>> for region in regions: + ... all_nodes.update(region.get_nodes()) + >>> assert all_nodes == set(range(len(graph.nodes))) + + Performance: + - Time: O(N) where N = number of nodes (each visited once) + - Space: O(N) for visited set and region storage + """ + logger.info(f"Partitioning graph ({len(self.graph.nodes)} nodes)") + logger.debug( + f"Initial state: {len(self.visited_nodes)} visited, {len(self.regions)} regions" + ) + + # Main partitioning loop: process each node in graph order + for node_idx in range(len(self.graph.nodes)): + self._build_region_from_node(node_idx) + + # Log completion and coverage statistics + coverage_pct = ( + 100 * len(self.visited_nodes) / len(self.graph.nodes) if self.graph.nodes else 0 + ) + logger.info( + f"Partitioning complete: {len(self.regions)} regions, " + f"{len(self.visited_nodes)}/{len(self.graph.nodes)} nodes ({coverage_pct:.1f}%)" + ) + + # Log summary statistics about region sizes + if self.regions: + region_sizes = [r.get_size() for r in self.regions] + avg_size = sum(region_sizes) / len(region_sizes) + min_size = min(region_sizes) + max_size = max(region_sizes) + logger.debug(f"Region sizes: min={min_size}, max={max_size}, avg={avg_size:.1f}") + + return self.regions + + +class TopDownRegionBuilder(RegionSearchBase): + """Top-down region refiner that creates hierarchical structure from initial regions. + + This class implements Phase 2 of the combined region search strategy. It takes + a region created by RegionPartitioner and refines it by: + 1. Identifying and merging converged sub-patterns + 2. Splitting long sequences into optimal sub-regions + 3. 
Creating a hierarchical COMPOSITE region structure
+
+    **Core Strategy:**
+    Starting with a flat LEAF region, creates a hierarchy by:
+
+    **Step 1: Merge Converged Regions**
+    - Identifies divergent nodes within the region
+    - Finds their convergence points
+    - Groups divergence+branches+convergence into sub-regions
+    - Leaves remaining nodes for sequence splitting
+
+    **Step 2: Split Sequence Regions**
+    - Takes ungrouped nodes (not part of converged patterns)
+    - Splits into individual node regions initially
+    - Merges adjacent nodes if they form producer-consumer chains
+    - Limits how many boundary operations (Conv, Gemm, etc.) a merged region may contain
+    - Limits region size to prevent overly large groups
+
+    **Step 3: Create Composite**
+    - Wraps all sub-regions into a single COMPOSITE region
+    - Computes hierarchical input/output boundaries
+    - Returns refined region with better internal structure
+
+    **Merging Criteria for Sequences:**
+    Two adjacent sequence regions can merge if ALL of:
+    - Producer region's outputs go to exactly one region (simple producer→consumer chain)
+    - Neither region is too large (< maximum_sequence_region_size nodes each)
+    - The merged region would contain at most 3 boundary operations (Conv, Gemm, etc.)
+    - Regions are adjacent in data flow (no gaps)
+
+    **Boundary Operations:**
+    These operation types count toward the boundary-operation limit when merging:
+    - Conv, ConvTranspose: Convolution layers
+    - Gemm, MatMul: Matrix multiplications
+    - AveragePool, MaxPool, GlobalAveragePool, GlobalMaxPool: Pooling
+    - Resize: Spatial resizing
+
+    **Example Transformation:**
+    ```
+    Input (flat LEAF region):
+        [Conv, BN, ReLU, Split, Path1_A, Path1_B, Path2_A, Path2_B, Concat]
+
+    Output (hierarchical COMPOSITE region):
+        COMPOSITE {
+            LEAF {Conv},           # Kept separate (e.g., output also feeds Split)
+            LEAF {BN, ReLU},       # Sequence merged
+            LEAF {Split},          # Divergent node
+            LEAF {Path1_A, Path1_B, Path2_A, Path2_B, Concat},  # Converged pattern
+        }
+    ```
+
+    **Key Features:**
+    - **Hierarchical Structure:** Creates parent-child region relationships
+    - **Pattern-Aware:** Recognizes convergence and sequence patterns
+    - **Size-Bounded:** Limits region sizes for optimal granularity
+    - **Boundary-Aware:** Bounds the number of boundary operations per merged region
+
+    **Inputs:**
+    - A LEAF region from RegionPartitioner (flat list of nodes)
+    - The graph structure
+    - Starting region ID for new regions
+
+    **Output:**
+    - A COMPOSITE region containing LEAF child regions
+    - Better internal structure reflecting computation patterns
+    - Same total nodes, but organized hierarchically
+
+    **Usage Pattern:**
+    ```python
+    # After partitioning
+    initial_region = partitioner.regions[0]
+
+    # Refine structure
+    builder = TopDownRegionBuilder(graph, initial_region, next_region_id=10)
+    refined_region = builder.build_composite_region()
+
+    # refined_region now has hierarchical structure
+    print(f"Children: {len(refined_region.get_children())}")
+    for child in refined_region.get_children():
+        print(f"  Child {child.get_id()}: {child.get_size()} nodes")
+    ```
+
+    **Performance:**
+    - Time: O(N + E) where N = nodes in region, E = edges between them
+    - Space: O(N) for temporary data structures
+
+    Attributes:
+        graph: The computation graph
+        root: Input region to refine (becomes search space)
+        regions: Output list of refined regions (typically one COMPOSITE)
+        next_region_id: Counter for assigning unique IDs to new regions
+        boundary_op_types: Set of operation types treated as boundaries
+        maximum_sequence_region_size: Maximum number of nodes allowed in a sequence region
+            during 
merging. Prevents overly large regions (default: 10) + + See Also: + RegionPartitioner: Creates initial regions for refinement + CombinedRegionSearch: Orchestrates partitioning and refinement + """ + + def __init__( + self, + graph: gs.Graph, + root: Region, + next_region_id: int = 0, + maximum_sequence_region_size: int = 10, + ): + """Initialize the refiner with a region to refine. + + Args: + graph: The ONNX graph (onnx_graphsurgeon.Graph) + root: The region to refine (typically from RegionPartitioner) + next_region_id: Starting ID for new regions created during refinement + maximum_sequence_region_size: Maximum nodes per sequence region during merging (default: 10) + """ + super().__init__(graph, root=root) + self.regions: list[Region] = [] + self.next_region_id = next_region_id + self.maximum_sequence_region_size = maximum_sequence_region_size + self.boundary_op_types = { + "Conv", + "ConvTranspose", + "Gemm", + "MatMul", + "AveragePool", + "MaxPool", + "GlobalAveragePool", + "GlobalMaxPool", + "Resize", + } + + def _create_leaf_region(self, node_indices: set[int]) -> Region: + """Create a new LEAF region containing specified nodes. + + Helper method to construct a properly configured LEAF region: + - Assigns unique region ID + - Sets level one deeper than root + - Adds all specified nodes + - Computes input/output tensor boundaries + + Args: + node_indices: Set of node indices to include in the region + + Returns: + New LEAF region containing the specified nodes with computed boundaries + + Side Effects: + Increments next_region_id counter + """ + region = Region( + region_id=self.next_region_id, level=self.root.level + 1, region_type=RegionType.LEAF + ) + self.next_region_id += 1 + for node_idx in node_indices: + region.add_node(node_idx) + self.compute_region_boundaries(region) + return region + + def _build_region_usage_map(self, regions: list[Region]) -> dict[str, list[Region]]: + """Build mapping from tensor names to regions that consume them. + + Similar to tensor_users_map but at the region level instead of node level. + This enables efficient traversal of region dependencies for merging decisions. + + **Purpose:** + Used during sequence splitting to identify producer-consumer chains + between regions. If a tensor is consumed by only one region, that + region might be mergeable with its producer. + + Args: + regions: List of regions to analyze + + Returns: + Dictionary mapping tensor names to lists of regions that consume them. + Tensors with len(consumers) == 1 indicate potential merge opportunities. + + Example: + >>> # Tensor "X" consumed by region 5 and region 7 + >>> usage_map["X"] == [region5, region7] + """ + region_usage_map: dict[str, list[Region]] = {} + for region in regions: + for tensor_name in region.inputs: + if tensor_name not in region_usage_map: + region_usage_map[tensor_name] = [] + region_usage_map[tensor_name].append(region) + return region_usage_map + + def _split_sequence_regions(self, root: Region) -> list[Region]: + """Split a region into smaller sub-regions by merging producer-consumer chains. + + Takes a region and creates optimal sub-regions by: + 1. Initially splitting into individual single-node regions + 2. Traversing in data flow order (following tensor dependencies) + 3. Merging adjacent regions that form simple producer-consumer chains + 4. Respecting boundary operations and size limits + + **Algorithm:** + ``` + 1. Create one LEAF region per node + 2. Build tensor → consuming regions map + 3. 
Traverse regions in data flow order (BFS from inputs):
+           For each region:
+               Check if all outputs go to single consumer region
+               If yes and merge criteria met:
+                   Merge this region into consumer region
+                   Mark this region as removed
+        4. Return regions not marked as removed
+        ```
+
+        **Merge Criteria (ALL must be true):**
+        - All outputs of producer go to exactly one consumer (simple chain)
+        - Producer region size < maximum_sequence_region_size (avoid overly large regions)
+        - Consumer region size < maximum_sequence_region_size (avoid overly large regions)
+        - Merged region would contain at most 3 boundary ops (Conv, etc.)
+        - Consumer not already removed (merged elsewhere)
+
+        **Boundary Operations:**
+        These ops count toward the merged region's boundary-op limit:
+        Conv, ConvTranspose, Gemm, MatMul, Pooling ops, Resize
+
+        **Example:**
+        ```
+        Input nodes: [Conv, BN, ReLU, Add]
+
+        Initial: Region{Conv}, Region{BN}, Region{ReLU}, Region{Add}
+
+        Processing:
+        - Conv outputs only to BN; merged region has 1 boundary op → merge to Region{Conv, BN}
+        - Region{Conv, BN} outputs only to ReLU → merge to Region{Conv, BN, ReLU}
+        - Region{Conv, BN, ReLU} outputs only to Add → merge to Region{Conv, BN, ReLU, Add}
+
+        Final: Region{Conv, BN, ReLU, Add}
+        ```
+
+        **Purpose:**
+        Groups simple sequential operations while bounding how many compute-heavy
+        operations (Conv, Gemm) land in any single region.
+
+        Args:
+            root: Region to split (contains nodes to partition into sub-regions)
+
+        Returns:
+            List of LEAF regions that partition the root's nodes with better
+            granularity than one-node-per-region or all-in-one.
+
+        Note:
+            This is the "bottom" of the top-down strategy - splits fine-grained,
+            then merges selectively based on data flow patterns.
+        """
+        result_regions: list[Region] = []
+        removed_regions: set[int] = set()
+
+        # =====================================================================
+        # PHASE 1: Split into Single-Node Regions
+        # =====================================================================
+        # Start with maximum granularity: one region per node.
+        # This gives us the most flexibility for selective merging.
+        for node_idx in root.get_nodes():
+            region = Region(
+                region_id=self.next_region_id, level=root.level + 1, region_type=RegionType.LEAF
+            )
+            region.add_node(node_idx)
+            self.compute_region_boundaries(region)
+            result_regions.append(region)
+            self.next_region_id += 1
+
+        # Build map: tensor_name -> regions that consume it
+        # Enables efficient lookup of producer-consumer relationships
+        region_usage_map = self._build_region_usage_map(result_regions)
+
+        # =====================================================================
+        # PHASE 2: Merge Regions in Data Flow Order
+        # =====================================================================
+        # Traverse regions following data flow (BFS from inputs).
+        # At each step, check if producer can merge with consumer.
+        # This creates longer sequences while respecting constraints.
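+        #
+        # Merge direction is upstream-into-downstream: common_use_region.merge(consumer)
+        # below keeps the downstream region's identity, so a chain A → B → C collapses
+        # step by step into the region that consumes the chain's final tensor.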
+ + # Start from root's input tensors and traverse forward + queue = deque(root.get_inputs()) + + while len(queue) > 0: + tensor_name = queue.popleft() + + # Skip tensors not produced by any region in our scope + if tensor_name not in region_usage_map: + continue + + # Process each region consuming this tensor (potential merge targets) + consumers = region_usage_map[tensor_name] + for consumer in consumers: + # Skip regions already merged into others + if consumer.get_id() in removed_regions: + continue + + # ------------------------------------------------------------- + # Check if this consumer can merge with its downstream region + # ------------------------------------------------------------- + # Merging criteria: ALL outputs go to same single region + common_use_region = None + can_merge = True + + # Check all outputs of the consumer region + for output_tensor in consumer.outputs: + # Add output to queue for continued traversal + queue.append(output_tensor) + + # Check if output has consumers in our region set + if output_tensor not in region_usage_map: + # Output goes outside (or nowhere) - can't merge + can_merge = False + break + + # Get regions consuming this output + use_regions = region_usage_map[output_tensor] + + # Must go to exactly ONE region (simple chain) + if len(use_regions) != 1: + # Branches to multiple regions - can't merge + can_merge = False + break + + # Check if all outputs go to the SAME region + if common_use_region is None: + # First output: remember its consumer + common_use_region = use_regions[0] + elif common_use_region != use_regions[0]: + # Different outputs go to different regions - can't merge + can_merge = False + break + + # No valid downstream region to merge with + if common_use_region is None or common_use_region.get_id() in removed_regions: + can_merge = False + continue + + # ------------------------------------------------------------- + # Apply Additional Constraints + # ------------------------------------------------------------- + + # Constraint 1: Limit the number of boundary operations after merge + nodes_after_merge = set() + nodes_after_merge.update(consumer.get_nodes()) + nodes_after_merge.update(common_use_region.get_nodes()) + node_ops = [self.graph.nodes[idx].op for idx in nodes_after_merge] + boundary_op_count = sum( + [1 if op in self.boundary_op_types else 0 for op in node_ops] + ) + + if boundary_op_count > 3: + can_merge = False + continue + + # Constraint 2: Size limits to avoid overly large regions + # Keep regions manageable for optimization passes + if ( + consumer.get_size() >= self.maximum_sequence_region_size + or common_use_region.get_size() >= self.maximum_sequence_region_size + ): + # One or both regions too large - don't merge + can_merge = False + continue + + # ------------------------------------------------------------- + # Perform Merge + # ------------------------------------------------------------- + # All criteria met: merge consumer into its downstream region + if can_merge: + common_use_region.merge(consumer) + removed_regions.add(consumer.get_id()) + + # ===================================================================== + # PHASE 3: Cleanup and Finalize + # ===================================================================== + # Remove regions that were merged into others + result_regions = [ + region for region in result_regions if region.get_id() not in removed_regions + ] + + # Recompute boundaries for all remaining regions + # (merging may have changed input/output tensors) + for region in 
result_regions: + self.compute_region_boundaries(region) + + return result_regions + + def _merge_converged_regions(self, root: Region): + """Identify and merge convergence patterns within a region. + + Traverses the region to find divergent nodes and their convergence points, + creating sub-regions that capture divergence→branches→convergence patterns. + Nodes not part of any convergence pattern are left for sequence splitting. + + **Algorithm:** + ``` + 1. Traverse region in data flow order (BFS from inputs) + 2. For each node: + If node is divergent (branches): + Find convergence point + If convergence exists within root: + Create LEAF region with all nodes between divergence and convergence + Mark those nodes as removed (grouped) + 3. Create LEAF region for remaining ungrouped nodes + 4. Return all created regions + ``` + + **Convergence Detection:** + Uses inherited _find_converge_nodes() to identify where branches rejoin. + Only creates convergence regions if the convergence point is within + the root region being refined. + + **Example:** + ``` + Root contains: [A, B, C, D, E, F, G] + + Graph structure: + A → B (divergent) → C, D + C → E + D → E (convergence) + E → F → G + + Result: + - Region1 {B, C, D, E}: Convergence pattern + - Region2 {A, F, G}: Remaining sequence nodes + ``` + + **Use Case:** + Captures patterns like: + - ResNet skip connections (Conv branch + identity → Add) + - Inception modules (multiple parallel conv paths → Concat) + - Attention mechanisms (Q/K/V branches → attention computation) + + **Limitations:** + - Only finds convergence patterns where convergence is in root region + - Nodes can only belong to one convergence pattern (first match wins) + - Uses intersection with root nodes to ensure boundaries respected + + Args: + root: Region to analyze for convergence patterns + + Returns: + List of LEAF regions: + - Some containing convergence patterns (divergence + branches + convergence) + - One containing remaining nodes not part of any pattern + + Note: + This is the "top" of the top-down strategy - identifies high-level + patterns first, then delegates remaining nodes to sequence splitting. 
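+
+        Example:
+            A minimal usage sketch (assuming ``builder`` is a TopDownRegionBuilder
+            constructed over a region produced by RegionPartitioner):
+
+            >>> sub_regions = builder._merge_converged_regions(builder.root)
+            >>> for r in sub_regions:
+            ...     print(f"Region {r.get_id()}: {r.get_size()} nodes")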
+ """ + result_regions: list[Region] = [] + removed_nodes: set[int] = set() + queue = deque(root.get_inputs()) + while len(queue) > 0: + tensor_name = queue.popleft() + if tensor_name not in self.tensor_users_map: + continue + consumer_nodes = self.tensor_users_map[tensor_name] + for node_idx in consumer_nodes: + # stop at boundary nodes + if node_idx not in root.get_nodes(): + continue + consumer = self.graph.nodes[node_idx] + for output_tensor in consumer.outputs: + if output_tensor.name not in self.tensor_users_map: + continue + queue.append(output_tensor.name) + # if the node is already in a region, skip + if node_idx in removed_nodes: + continue + if not self._is_node_divergent(node_idx): + continue + converge_node_idx, visited_nodes = self._find_converge_nodes(node_idx) + visited_nodes = visited_nodes.intersection(root.get_all_nodes_recursive()) + # if no convergence found, skip + if converge_node_idx is None: + continue + # group converged nodes into a region + if converge_node_idx in root.get_nodes(): + converged_region = self._create_leaf_region(visited_nodes) + result_regions.append(converged_region) + removed_nodes.update(visited_nodes) + continue + # create a leaf region for the remaining nodes + remaining_nodes = root.get_nodes() - removed_nodes + if len(remaining_nodes) > 0: + result_regions.append(self._create_leaf_region(remaining_nodes)) + # compute region boundaries for all regions + for region in result_regions: + self.compute_region_boundaries(region) + return result_regions + + def build_composite_region(self) -> Region: + """Refine a flat region into a hierarchical COMPOSITE region. + + This is the main entry point for top-down refinement. Transforms a flat + LEAF region from RegionPartitioner into a hierarchical structure with + better internal organization. + + **Three-Stage Algorithm:** + + **Stage 1: Merge Converged Patterns** + Identifies divergence→convergence patterns and groups them: + - Finds divergent nodes where computation branches + - Locates convergence points where branches rejoin + - Creates sub-regions for complete convergence patterns + - Leaves ungrouped nodes for next stage + + **Stage 2: Split Sequence Regions** + Takes remaining (ungrouped) nodes and optimizes granularity: + - Splits into fine-grained (single-node) regions + - Merges adjacent regions forming producer-consumer chains + - Respects boundary operations (Conv, Gemm, etc.) 
- Limits region sizes to avoid overly large groups
+
+        **Stage 3: Create Composite Wrapper**
+        Wraps all refined sub-regions into hierarchy:
+        - Creates COMPOSITE region at same level as input root
+        - Adds all refined LEAF regions as children
+        - Computes input/output boundaries for composite
+        - Returns single COMPOSITE containing hierarchical structure
+
+        **Transformation Example:**
+        ```
+        Input (flat LEAF region from partitioner):
+            Region(nodes=[0,1,2,3,4,5,6,7,8])
+
+        After Stage 1 (converged patterns):
+            [Region{0,1,2}, Region{3,4,5,6,7,8}]  # Found one convergence
+
+        After Stage 2 (sequence splitting):
+            [Region{0,1,2}, Region{3}, Region{4,5,6}, Region{7,8}]
+
+        After Stage 3 (composite wrapping):
+            COMPOSITE {
+                LEAF{0,1,2},    # Convergence pattern
+                LEAF{3},        # Boundary op
+                LEAF{4,5,6},    # Merged sequence
+                LEAF{7,8}       # Merged sequence
+            }
+        ```
+
+        **Benefits:**
+        - **Better Granularity:** Not too coarse, not too fine
+        - **Pattern Recognition:** Convergence patterns kept together
+        - **Optimization-Friendly:** Boundary ops isolated for targeting
+        - **Hierarchical:** Enables recursive optimization strategies
+
+        **Invariants Maintained:**
+        - Total node count unchanged (reorganization only)
+        - All nodes assigned to exactly one LEAF region
+        - LEAF regions don't overlap
+        - Parent-child relationships properly formed
+
+        **Output Format:**
+        Always returns a single region:
+        - If refinement produced more than one sub-region: a COMPOSITE region with LEAF children
+        - If refinement produced exactly one sub-region: that single LEAF region unchanged
+
+        Returns:
+            COMPOSITE region containing hierarchically organized LEAF sub-regions.
+            The composite represents the same nodes as input root but with
+            better internal structure reflecting computation patterns.
+
+        Example:
+            >>> builder = TopDownRegionBuilder(graph, flat_region, next_region_id=10)
+            >>> refined = builder.build_composite_region()
+            >>> print(f"Type: {refined.get_type()}")  # COMPOSITE
+            >>> print(f"Children: {len(refined.get_children())}")  # 4-10 typically
+            >>> for child in refined.get_children():
+            ...     print(f"  {child.get_id()}: {child.get_size()} nodes")
+        """
+        # merge converged regions into composite regions
+        self.regions = self._merge_converged_regions(self.root)
+        # split sequence regions into smaller regions
+        result_regions: list[Region] = []
+        for region in self.regions:
+            result_regions.extend(self._split_sequence_regions(region))
+        for region in result_regions:
+            self.compute_region_boundaries(region, include_constant=True)
+        self.regions = result_regions
+        # merge all regions into a single composite region
+        if len(self.regions) > 1:
+            composite = Region(
+                region_id=self.next_region_id,
+                level=self.root.level,
+                region_type=RegionType.COMPOSITE,
+            )
+            self.next_region_id += 1
+            self.regions = sorted(
+                self.regions, key=lambda x: RegionPattern.from_region(x, self.graph).signature
+            )
+            for region in self.regions:
+                composite.add_child(region)
+            self.compute_region_boundaries(composite)
+            self.regions = [composite]
+        return self.regions[0]
+
+
+class CombinedRegionSearch(RegionSearchBase):
+    """Two-phase region search combining bottom-up partitioning with top-down refinement.
+ + This class implements a sophisticated region discovery algorithm that combines two + complementary strategies to create well-formed, hierarchical regions from an ONNX + computation graph: + + **Phase 1: Bottom-Up Partitioning (RegionPartitioner)** + - Traverses the graph from inputs to outputs + - Identifies divergent nodes (nodes with outputs consumed by multiple branches) + - Finds convergence points where divergent branches rejoin + - Creates initial LEAF regions based on divergence/convergence patterns + - Groups linear sequences of non-divergent nodes together + + **Phase 2: Top-Down Refinement (TopDownRegionBuilder)** + - Takes each region from Phase 1 as input + - Identifies and merges converged sub-regions within each region + - Splits long sequences into smaller, more manageable regions + - Creates COMPOSITE regions with hierarchical structure + - Ensures region boundaries align with natural computation patterns + + **Key Features:** + - **Comprehensive Coverage:** Visits all nodes in the graph + - **Hierarchical Structure:** Creates multi-level region hierarchies + - **Pattern Recognition:** Identifies divergence/convergence patterns + - **Boundary Computation:** Automatically computes input/output tensors for each region + - **Quality Metrics:** Provides coverage and node count statistics + + **Region Types Created:** + - LEAF regions: Basic building blocks containing graph nodes + - COMPOSITE regions: Higher-level regions containing child regions + + **Use Cases:** + - Graph partitioning for distributed execution + - Identifying optimization boundaries for quantization/pruning + - Creating sub-graphs for incremental processing + - Analyzing graph structure and dependencies + + **Algorithm Overview:** + 1. Initialize RegionPartitioner for bottom-up search + 2. Partition graph into initial LEAF regions + 3. For each initial region: + a. Merge converged sub-regions + b. Split long sequences into smaller regions + c. Create COMPOSITE region hierarchy + 4. Compute final region boundaries + + **Output:** + A list of COMPOSITE regions that collectively cover the entire graph, + each containing a hierarchical structure of child regions. + + **Example:** + >>> search = CombinedRegionSearch(graph) + >>> regions = search.search_regions() + >>> print(f"Created {len(regions)} top-level regions") + >>> for region in regions: + ... 
print(f"Region {region.get_id()}: {region.get_size()} nodes") + + **Performance Considerations:** + - Complexity depends on graph structure (divergence/convergence patterns) + - Pre-computes forward-reachable nodes for efficient convergence detection + - Uses BFS for systematic graph traversal + + **Validation:** + - Logs warnings if node counts change during refinement + - Verifies coverage of all nodes in the graph + - Ensures no duplicate nodes across regions + + Attributes: + graph: The ONNX graph to partition (onnx_graphsurgeon.Graph) + regions: List of top-level COMPOSITE regions created by the search + region_partitioner: Internal RegionPartitioner instance + root: Root region containing all graph nodes (inherited from RegionSearchBase) + tensor_users_map: Mapping from tensor names to consuming node indices + forward_reachable_nodes_map: Pre-computed forward reachability information + maximum_sequence_region_size: Maximum nodes per sequence region during merging + """ + + def __init__( + self, + graph: gs.Graph, + maximum_sequence_region_size: int = 10, + minimum_topdown_search_size: int = 10, + ): + """Initialize CombinedRegionSearch for a given ONNX graph. + + Sets up the necessary data structures for two-phase region search: + - Initializes base class with graph and builds root region + - Creates empty regions list for storing results + - Initializes RegionPartitioner for Phase 1 bottom-up search + - Pre-computes tensor users map and forward reachability information + + Args: + graph: The ONNX graph to partition (onnx_graphsurgeon.Graph). + Must be a valid, connected computation graph. + maximum_sequence_region_size: Maximum nodes per sequence region during merging + in Phase 2 refinement (default: 10) + minimum_topdown_search_size: Minimum nodes per region to search during top-down refinement (default: 10) + + Note: + Initialization performs pre-computation that scales with graph size. + For very large graphs, this may take significant time. + + Example: + >>> import onnx_graphsurgeon as gs + >>> import onnx + >>> model = onnx.load("model.onnx") + >>> graph = gs.import_onnx(model) + >>> search = CombinedRegionSearch(graph, maximum_sequence_region_size=10) + """ + super().__init__(graph) + self.regions: list[Region] = [] + self.region_partitioner = RegionPartitioner(graph) + self.minimum_topdown_search_size = minimum_topdown_search_size + self.maximum_sequence_region_size = maximum_sequence_region_size + + def search_regions(self) -> list[Region]: + """Execute two-phase region search to partition the graph into hierarchical regions. + + This is the main entry point for the CombinedRegionSearch algorithm. It performs + a sophisticated two-phase analysis of the computation graph: + + **Phase 1: Bottom-Up Partitioning** + Uses RegionPartitioner to create initial regions by: + - Traversing graph from inputs to outputs + - Identifying divergent nodes (where computation branches) + - Finding convergence points (where branches rejoin) + - Grouping linear sequences of operations + - Creating initial LEAF regions based on these patterns + + **Phase 2: Top-Down Refinement** + For each region from Phase 1, uses TopDownRegionBuilder to: + - Identify and merge converged sub-patterns within the region + - Split long sequences into smaller, more manageable regions + - Create hierarchical COMPOSITE region structures + - Ensure optimal region granularity for optimization + + **Algorithm Steps:** + 1. Initialize RegionPartitioner with the graph + 2. 
Partition graph into initial regions (Phase 1)
+        3. Log partitioning statistics (coverage, region count)
+        4. For each initial region:
+           a. Create TopDownRegionBuilder for refinement
+           b. Share tensor users map for efficient lookups
+           c. Build composite region hierarchy (Phase 2)
+           d. Validate node count consistency
+           e. Recompute region boundaries
+        5. Return final list of refined regions
+
+        **Output Structure:**
+        Each returned region is typically a COMPOSITE region containing:
+        - LEAF child regions with actual graph nodes
+        - Computed input/output tensor boundaries
+        - Hierarchical structure reflecting computation patterns
+
+        **Quality Metrics Logged:**
+        - Total regions found: Number of top-level regions created
+        - Total nodes visited: How many graph nodes were processed
+        - Coverage percentage: What fraction of the graph was partitioned
+
+        **Validation:**
+        - Warns if node counts change during refinement (potential bug)
+        - Ensures all nodes are accounted for
+        - Verifies region boundary consistency
+
+        Returns:
+            List of Region objects representing the partitioned graph.
+            Each region is a COMPOSITE region with a hierarchical structure
+            of child regions. The regions collectively cover all nodes in
+            the graph without overlap.
+
+        Raises:
+            May propagate exceptions from RegionPartitioner or TopDownRegionBuilder
+            if graph structure is invalid or contains unsupported patterns.
+
+        Example:
+            >>> search = CombinedRegionSearch(graph)
+            >>> regions = search.search_regions()
+            >>> print(f"Graph partitioned into {len(regions)} regions")
+            >>> # Analyze results
+            >>> total_nodes = sum(len(r.get_all_nodes_recursive()) for r in regions)
+            >>> print(f"Total nodes in all regions: {total_nodes}")
+            >>> # Print hierarchical structure
+            >>> for region in regions:
+            ...     search.print_tree(region)
+
+        Note:
+            This method modifies self.regions and returns it. Calling this
+            method multiple times will overwrite previous results.
+
+        See Also:
+            RegionPartitioner: Phase 1 bottom-up partitioning
+            TopDownRegionBuilder: Phase 2 top-down refinement
+            print_tree: Visualize the resulting region hierarchy
+        """
+        # =====================================================================
+        # PHASE 1: Bottom-Up Partitioning
+        # =====================================================================
+        # Create a fresh RegionPartitioner instance for this search.
+        # This performs initial graph analysis including:
+        # - Building tensor-to-users mapping for tracking data flow
+        # - Computing forward reachability for convergence detection
+        logger.info("Phase 1: Bottom-up partitioning")
+        logger.debug("Initializing RegionPartitioner")
+        region_partitioner = RegionPartitioner(self.graph)
+
+        # Execute the bottom-up partitioning algorithm.
+        # This traverses the graph and creates initial LEAF regions based on:
+        # - Divergence/convergence patterns (where computation branches/rejoins)
+        # - Linear sequences of non-divergent nodes
+        # - Graph structure and operation types
+        self.regions = region_partitioner.partition_graph()
+
+        # =====================================================================
+        # Log Phase 1 Results
+        # =====================================================================
+        # Report statistics about the initial partitioning to help understand
+        # graph structure and verify complete coverage.
+        coverage_pct = (
+            100 * len(region_partitioner.visited_nodes) / len(self.graph.nodes)
+            if self.graph.nodes
+            else 0
+        )
+        logger.info(
+            f"Phase 1 complete: {len(self.regions)} regions, "
+            f"{len(region_partitioner.visited_nodes)}/{len(self.graph.nodes)} nodes ({coverage_pct:.1f}%)"
+        )
+        logger.debug("Proceeding to Phase 2: Top-down refinement")
+
+        # =====================================================================
+        # PHASE 2: Top-Down Refinement
+        # =====================================================================
+        # Track the next available region ID to ensure unique IDs across all regions.
+        # This is important because we'll be creating new regions during refinement.
+        logger.info("Phase 2: Top-down refinement")
+        next_region_id = region_partitioner.current_region_id
+
+        # Process each initial region to refine its structure.
+        # Each region from Phase 1 becomes a root for hierarchical refinement.
+        refined_count = 0
+        skipped_count = 0
+        for idx in range(len(self.regions)):
+            total_nodes = len(self.regions[idx].get_all_nodes_recursive())
+            if total_nodes < self.minimum_topdown_search_size:
+                logger.debug(f"Skipping region {idx}: {total_nodes} nodes (below minimum)")
+                skipped_count += 1
+                continue
+
+            # Create a TopDownRegionBuilder for this specific region.
+            # This builder will analyze the region and create a hierarchical
+            # structure of child regions based on internal patterns.
+            logger.debug(f"Refining region {idx}: {total_nodes} nodes")
+            region_builder = TopDownRegionBuilder(
+                self.graph,
+                self.regions[idx],
+                next_region_id=next_region_id,
+                maximum_sequence_region_size=self.maximum_sequence_region_size,
+            )
+
+            # Share the tensor users map from Phase 1 to avoid recomputation.
+            # This map is expensive to build and is shared across all refinements.
+            region_builder.tensor_users_map = region_partitioner.tensor_users_map
+
+            # Track node count for validation.
+            # The refinement should reorganize nodes into hierarchies without
+            # losing or duplicating any nodes.
+            node_count_before = len(self.regions[idx].get_all_nodes_recursive())
+
+            # Execute top-down refinement on this region.
+            # This creates a COMPOSITE region with hierarchical structure:
+            # 1. Merges converged sub-regions (nodes between divergence/convergence)
+            # 2. Splits long sequences into smaller regions
+            # 3. Creates appropriate parent-child relationships
+
+            self.regions[idx] = region_builder.build_composite_region()
+
+            # Validate that refinement preserved all nodes.
+            # A mismatch indicates a bug in the refinement logic.
+            node_count_after = len(self.regions[idx].get_all_nodes_recursive())
+            if node_count_before != node_count_after:
+                logger.warning(
+                    f"Node count mismatch in region {idx}: {node_count_before} → {node_count_after}"
+                )
+
+            # Recompute region boundaries after refinement.
+            # The hierarchical structure may have changed the input/output
+            # tensors at the top level of this region.
+            region_partitioner.compute_region_boundaries(self.regions[idx])
+
+            # Update next_region_id for the next iteration.
+            # Each builder may have created new regions with new IDs.
+
+            next_region_id = region_builder.next_region_id
+            refined_count += 1
+
+        logger.info(f"Phase 2 complete: refined {refined_count} regions, skipped {skipped_count}")
+
+        # Return the final refined regions
+        return self.regions
+
+
+# =============================================================================
+# Region Search Inspection Tool
+# =============================================================================
+
+
+def inspect_region_search(
+    onnx_path: str,
+    max_sequence_size: int = 10,
+    include_all_regions: bool = False,
+) -> list[Region]:
+    """Inspect region search results for an ONNX model.
+
+    This function loads an ONNX model, runs CombinedRegionSearch (which performs
+    both bottom-up partitioning and top-down refinement internally), and prints
+    detailed information about the discovered regions including their hierarchical
+    structure.
+
+    **What it does:**
+    1. Loads ONNX model and converts to GraphSurgeon format
+    2. Creates CombinedRegionSearch instance with specified parameters
+    3. Runs two-phase search (partitioning + refinement) via search_regions()
+    4. Displays detailed region structure and statistics
+    5. Returns the final list of refined regions
+
+    **Output Sections:**
+    - Initialization: Shows search parameters
+    - Two-Phase Search: Runs automatically via CombinedRegionSearch.search_regions()
+    - Detailed Structure: Shows each region's hierarchy and properties
+    - Summary Statistics: Shows region counts and node coverage
+
+    Args:
+        onnx_path: Path to the ONNX model file
+        max_sequence_size: Maximum size for sequence regions during refinement (default: 10)
+        include_all_regions: Include all regions, even those without major quantizable
+            operations (Conv, MatMul, etc.). Default: False (skips such regions)
+
+    Returns:
+        List of discovered and refined regions (LEAF and COMPOSITE)
+
+    Example:
+        >>> # Inspect model with default settings
+        >>> regions = inspect_region_search("model.onnx")
+        >>> print(f"Found {len(regions)} regions")
+        >>>
+        >>> # Custom sequence size
+        >>> regions = inspect_region_search("model.onnx", max_sequence_size=20)
+        >>>
+        >>> # Include all regions
+        >>> regions = inspect_region_search("model.onnx", include_all_regions=True)
+    """
+    # Load ONNX model
+    logger.info(f"Loading model: {onnx_path}")
+    onnx_model = onnx.load(onnx_path)
+
+    # Convert to onnx_graphsurgeon Graph
+    graph = gs.import_onnx(onnx_model)
+    graph.cleanup().toposort()
+    logger.info(
+        f"Loaded graph: {len(graph.nodes)} nodes, {len(graph.inputs)} inputs, {len(graph.outputs)} outputs"
+    )
+
+    # Initialize CombinedRegionSearch (contains RegionPartitioner internally)
+    logger.debug(
+        f"Search parameters: max_steps={DEFAULT_MAX_STEPS}, max_sequence_size={max_sequence_size}"
+    )
+
+    combined_search = CombinedRegionSearch(graph, maximum_sequence_region_size=max_sequence_size)
+
+    # Run complete two-phase region search
+    logger.info("Running region search")
+    regions = combined_search.search_regions()
+
+    # Show detailed region structure
+    logger.info("Analyzing region structure")
+    all_regions = []
+    for i, region in enumerate(regions):
+        # Iterate over a copy: removing children from the live list while
+        # iterating it would skip the siblings of removed entries
+        for child in list(region.get_children()):
+            if not include_all_regions and not has_quantizable_operations(child, graph):
+                region.remove_child(child)
+        if not include_all_regions and not has_quantizable_operations(region, graph):
+            logger.debug(f"Filtered out region {i} (no quantizable operations)")
+            continue
+        logger.debug(
+            f"Region {i}: {region.get_type().value}, {len(region.get_all_nodes_recursive())} nodes, "
+            f"{len(region.inputs)} inputs, 
{len(region.outputs)} outputs" + ) + all_regions.append(region) + if region.get_type() == RegionType.COMPOSITE: + logger.debug(f" {len(region.get_children())} child regions") + all_regions.extend(region.get_children()) + combined_search.print_tree(region, indent=2) + + # Summary statistics + leaf_regions = sum(1 for r in all_regions if r.get_type() == RegionType.LEAF) + composite_regions = sum(1 for r in all_regions if r.get_type() == RegionType.COMPOSITE) + + all_nodes = set() + for region in all_regions: + all_nodes.update(region.get_all_nodes_recursive()) + total_nodes = len(all_nodes) + coverage_pct = 100 * total_nodes / len(graph.nodes) if graph.nodes else 0 + + logger.info( + f"Summary: {len(all_regions)} regions ({leaf_regions} LEAF, {composite_regions} COMPOSITE), " + f"{total_nodes}/{len(graph.nodes)} nodes ({coverage_pct:.1f}%)" + ) + + # Print histogram of region sizes + region_sizes = [ + len(r.get_all_nodes_recursive()) for r in all_regions if r.get_type() == RegionType.LEAF + ] + + if region_sizes: + min_size = min(region_sizes) + max_size = max(region_sizes) + avg_size = sum(region_sizes) / len(region_sizes) + + logger.info(f"LEAF region sizes: min={min_size}, max={max_size}, avg={avg_size:.1f}") + + # Create histogram bins + size_counts = Counter(region_sizes) + + # Display histogram + logger.debug("Size distribution:") + for size in sorted(size_counts.keys()): + count = size_counts[size] + bar = "█" * min(count, 50) # Cap bar length at 50 + logger.debug(f" {size:4d} nodes: {bar} ({count} regions)") + + return regions + + +def main(): + """Command-line entry point for region search inspection.""" + parser = argparse.ArgumentParser( + prog="modelopt.onnx.quantization.autotune.region_search", + description="Inspect region search results for ONNX models", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + # Basic inspection + python -m modelopt.onnx.quantization.autotune.region_search --model model.onnx + + # Verbose mode for debug logging + python -m modelopt.onnx.quantization.autotune.region_search \\ + --model model.onnx --verbose + + # Custom maximum sequence size + python -m modelopt.onnx.quantization.autotune.region_search \\ + --model model.onnx --max-sequence-size 20 + """, + ) + + parser.add_argument("--model", "-m", type=str, required=True, help="Path to ONNX model file") + parser.add_argument( + "--max-sequence-size", + type=int, + default=10, + help="Maximum size for sequence regions during refinement (default: 10)", + ) + parser.add_argument( + "--include-all-regions", + action="store_true", + help="Include all regions, even those without major quantizable operations. 
" + "Default: False (skips such regions)", + ) + parser.add_argument("--verbose", "-v", action="store_true", help="Enable verbose debug logging") + + args = parser.parse_args() + + # Configure logging + log_level = logging.DEBUG if args.verbose else logging.INFO + logging.basicConfig(level=log_level, format="%(asctime)s - %(levelname)s - %(message)s") + logger.setLevel(log_level) + + # Run inspection + try: + regions = inspect_region_search( + onnx_path=args.model, + max_sequence_size=args.max_sequence_size, + include_all_regions=args.include_all_regions, + ) + logger.info(f"✓ Inspection complete: {len(regions)} top-level regions discovered") + return 0 + except Exception as e: + logger.error(f"Inspection failed: {e}", exc_info=args.verbose) + return 1 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/modelopt/onnx/quantization/graph_utils.py b/modelopt/onnx/quantization/graph_utils.py index 67596d5df..a30a113ec 100755 --- a/modelopt/onnx/quantization/graph_utils.py +++ b/modelopt/onnx/quantization/graph_utils.py @@ -302,6 +302,23 @@ def get_tensor_consumer_nodes( return tensor_consumers +def get_tensor_consumer_node_indices(graph: onnx.GraphProto | gs.Graph) -> dict[str, list[int]]: + """Build a mapping from tensor names to the indices of nodes that use them. + + Args: + graph: ONNX GraphSurgeon graph to analyze + Returns: + Dictionary mapping tensor names to lists of node indices that consume them + """ + tensor_consumer_map: dict[str, list[int]] = defaultdict(list) + nodes = graph.nodes if isinstance(graph, gs.Graph) else graph.node + for node_idx, node in enumerate(nodes): + inputs = node.inputs if isinstance(node, gs.Node) else node.input + for tensor in inputs: + tensor_consumer_map[tensor.name].append(node_idx) + return tensor_consumer_map + + def filter_quantizable_kgen_heads( cask_fusible_partitions: list[list[Node]], kgen_partitions: list[list[Node]], diff --git a/modelopt/onnx/quantization/qdq_utils.py b/modelopt/onnx/quantization/qdq_utils.py index 026b8d062..f533a8ede 100644 --- a/modelopt/onnx/quantization/qdq_utils.py +++ b/modelopt/onnx/quantization/qdq_utils.py @@ -1035,3 +1035,47 @@ def cast_initializer_to_dtype( input_onnx = onnx.numpy_helper.from_array(input, input_name) input_onnx.data_type = onnx_dtype_map[dtype] initializer_map[input_name].CopyFrom(input_onnx) + + +def get_quantized_tensors(onnx_model: onnx.ModelProto) -> set[str]: + """Get the names of all quantized tensors from an ONNX model. + + This function identifies all QuantizeLinear nodes in the ONNX model + and extracts the names of tensors being quantized (the first input of + each QuantizeLinear node, excluding scale and zero-point inputs). 
+
+    Args:
+        onnx_model: ONNX model protobuf to analyze
+
+    Returns:
+        Set of tensor names that are the first inputs of DequantizeLinear
+        nodes (i.e., the quantized tensors in the model)
+
+    Example:
+        >>> import onnx
+        >>> from modelopt.onnx.quantization.qdq_utils import get_quantized_tensors
+        >>>
+        >>> # Load a quantized model
+        >>> model = onnx.load("quantized_model.onnx")
+        >>>
+        >>> # Get all quantized tensor names
+        >>> quantized_tensors = get_quantized_tensors(model)
+        >>> print(f"Found {len(quantized_tensors)} quantized tensors")
+        >>>
+        >>> # Use with autotuner to import insertion points
+        >>> from modelopt.onnx.quantization.autotune import QDQAutotuner
+        >>> autotuner = QDQAutotuner(new_model)
+        >>> autotuner.initialize()
+        >>> autotuner.import_insertion_points(quantized_tensors)
+    """
+    quantized_tensors = set()
+
+    for node in onnx_model.graph.node:
+        if node.op_type == "DequantizeLinear":
+            # First input is the quantized tensor
+            # (inputs[1] is scale, inputs[2] is zero-point)
+            if node.input:
+                quantized_tensors.add(node.input[0])
+
+    logger.debug(f"Found {len(quantized_tensors)} quantized tensors in ONNX model")
+    return quantized_tensors
diff --git a/tests/unit/onnx/quantization/autotune/test_pattern_cache.py b/tests/unit/onnx/quantization/autotune/test_pattern_cache.py
new file mode 100644
index 000000000..bd9b67f90
--- /dev/null
+++ b/tests/unit/onnx/quantization/autotune/test_pattern_cache.py
@@ -0,0 +1,237 @@
+# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Tests for PatternCache functionality in the autotuner.
+
+Tests pattern cache creation, serialization, and scheme management.
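+
+Run these tests with pytest, e.g.:
+
+    pytest tests/unit/onnx/quantization/autotune/test_pattern_cache.py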
+""" + +import os +import sys +import tempfile +import unittest + +# Add parent directory to path +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from modelopt.onnx.quantization.autotune.common import ( + InsertionScheme, + NodeInputInsertionPoint, + PatternCache, + PatternSchemes, +) +from modelopt.onnx.quantization.autotune.region_pattern import RegionPattern + + +class TestPatternCache(unittest.TestCase): + """Test PatternCache functionality.""" + + @staticmethod + def _create_test_pattern(signature: str, size: int = 2): + """Create a test RegionPattern.""" + return RegionPattern(signature=signature, size=size) + + def test_empty_cache_creation(self): + """Test creating an empty PatternCache.""" + cache = PatternCache() + + assert len(cache.pattern_schemes) == 0 + assert cache.pattern_schemes is not None + + def test_add_pattern_schemes(self): + """Test adding pattern schemes to cache.""" + cache = PatternCache() + + # Create a pattern scheme + pattern = self._create_test_pattern("Conv->Relu") + ps = PatternSchemes(pattern=pattern) + scheme = InsertionScheme() + scheme.latency_ms = 10.0 + ps.schemes.append(scheme) + + cache.add_pattern_schemes(ps) + + assert len(cache.pattern_schemes) == 1 + assert cache.pattern_schemes[0].pattern_signature == "Conv->Relu" + + def test_multiple_patterns(self): + """Test cache with multiple pattern schemes.""" + cache = PatternCache() + + # Add multiple patterns + pattern_sigs = ["Conv->Relu", "Gemm->Relu", "Conv->Add->Relu"] + for pattern_sig in pattern_sigs: + pattern = self._create_test_pattern(pattern_sig) + ps = PatternSchemes(pattern=pattern) + scheme = InsertionScheme() + scheme.latency_ms = 10.0 + len(pattern_sig) + ps.schemes.append(scheme) + cache.add_pattern_schemes(ps) + + assert len(cache.pattern_schemes) == 3 + found_patterns = [ps.pattern_signature for ps in cache.pattern_schemes] + for pattern_sig in pattern_sigs: + assert pattern_sig in found_patterns + + def test_serialization_empty(self): + """Test serialization of empty cache.""" + cache = PatternCache() + + data = cache.to_dict() + assert "pattern_schemes" in data + assert len(data["pattern_schemes"]) == 0 + + restored = PatternCache.from_dict(data) + assert len(restored.pattern_schemes) == 0 + + def test_serialization_with_data(self): + """Test serialization with pattern schemes.""" + # Create cache with minimum_distance=0 to keep both schemes + cache = PatternCache(minimum_distance=0) + + # Add a pattern scheme + pattern = self._create_test_pattern("Conv->Relu") + ps = PatternSchemes(pattern=pattern) + + # Create schemes that are sufficiently different (distance >= 4) + scheme1 = InsertionScheme() + scheme1.node_inputs = [NodeInputInsertionPoint(0, 0)] + scheme1.latency_ms = 10.0 + ps.schemes.append(scheme1) + + scheme2 = InsertionScheme() + scheme2.node_inputs = [ + NodeInputInsertionPoint(0, 0), + NodeInputInsertionPoint(1, 0), + NodeInputInsertionPoint(2, 0), + NodeInputInsertionPoint(3, 0), + NodeInputInsertionPoint(4, 0), # 5 total points, diff = 4 from scheme1 + ] + scheme2.latency_ms = 12.0 + ps.schemes.append(scheme2) + + cache.add_pattern_schemes(ps) + + # Serialize and restore + data = cache.to_dict() + restored = PatternCache.from_dict(data) + + assert len(restored.pattern_schemes) == 1 + + restored_ps = restored.pattern_schemes[0] + assert restored_ps.pattern_signature == "Conv->Relu" + assert len(restored_ps.schemes) == 2 + assert restored_ps.best_scheme_index == 0 + assert restored_ps.schemes[0].latency_ms == 10.0 + + def 
test_yaml_round_trip(self): + """Test saving and loading cache as YAML.""" + cache = PatternCache() + + # Add a pattern scheme + pattern = self._create_test_pattern("Gemm->Relu") + ps = PatternSchemes(pattern=pattern) + scheme = InsertionScheme() + scheme.latency_ms = 15.0 + ps.schemes.append(scheme) + cache.add_pattern_schemes(ps) + + # Save to YAML + with tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", delete=False) as f: + yaml_path = f.name + + try: + cache.save(yaml_path) + + # Load from YAML + restored = PatternCache.load(yaml_path) + + assert len(restored.pattern_schemes) == 1 + assert restored.pattern_schemes[0].pattern_signature == "Gemm->Relu" + assert restored.pattern_schemes[0].schemes[0].latency_ms == 15.0 + finally: + if os.path.exists(yaml_path): + os.unlink(yaml_path) + + def test_update_cache(self): + """Test updating existing pattern in cache (merges schemes).""" + # Use minimum_distance=0 to keep all schemes + cache = PatternCache(minimum_distance=0) + + # Add initial pattern + pattern1 = self._create_test_pattern("Conv->Relu") + ps1 = PatternSchemes(pattern=pattern1) + scheme1 = InsertionScheme() + scheme1.latency_ms = 10.0 + ps1.schemes.append(scheme1) + cache.add_pattern_schemes(ps1) + + # Update with new scheme for same pattern + pattern2 = self._create_test_pattern("Conv->Relu") + ps2 = PatternSchemes(pattern=pattern2) + scheme2 = InsertionScheme() + scheme2.latency_ms = 8.0 # Better performance + scheme2.node_inputs = [NodeInputInsertionPoint(0, 0)] # Make it different + ps2.schemes.append(scheme2) + cache.add_pattern_schemes(ps2) + + # Verify merge (should have both schemes now) + assert len(cache.pattern_schemes) == 1 + conv_relu_ps = cache.pattern_schemes[0] + assert conv_relu_ps.pattern_signature == "Conv->Relu" + assert len(conv_relu_ps.schemes) == 2 # Merged + # Best scheme should be the one with lowest latency + assert conv_relu_ps.best_scheme.latency_ms == 8.0 + + def test_get_best_scheme(self): + """Test retrieving best scheme for a pattern.""" + # Use minimum_distance=0 to keep all different schemes + cache = PatternCache(minimum_distance=0) + + pattern = self._create_test_pattern("Conv->Relu") + ps = PatternSchemes(pattern=pattern) + + # Add multiple schemes with different insertion points + scheme1 = InsertionScheme() + scheme1.node_inputs = [NodeInputInsertionPoint(0, 0)] + scheme1.latency_ms = 12.0 + ps.schemes.append(scheme1) + + scheme2 = InsertionScheme() + scheme2.node_inputs = [NodeInputInsertionPoint(1, 0)] # Different node + scheme2.latency_ms = 8.0 + ps.schemes.append(scheme2) + + scheme3 = InsertionScheme() + scheme3.node_inputs = [NodeInputInsertionPoint(2, 0)] # Different node + scheme3.latency_ms = 10.0 + ps.schemes.append(scheme3) + + cache.add_pattern_schemes(ps) + + # Verify best scheme retrieval (automatically computed) + conv_relu_ps = cache.pattern_schemes[0] + assert conv_relu_ps.pattern_signature == "Conv->Relu" + assert len(conv_relu_ps.schemes) == 3 # All 3 kept + + # Verify best scheme has lowest latency (cache may reorder schemes) + best = conv_relu_ps.best_scheme + assert best is not None + assert best.latency_ms == 8.0 + + # Verify all three latencies are present + latencies = sorted([s.latency_ms for s in conv_relu_ps.schemes]) + assert latencies == [8.0, 10.0, 12.0] diff --git a/tests/unit/onnx/quantization/autotune/test_region_pattern.py b/tests/unit/onnx/quantization/autotune/test_region_pattern.py new file mode 100644 index 000000000..2a39f4cae --- /dev/null +++ 
b/tests/unit/onnx/quantization/autotune/test_region_pattern.py @@ -0,0 +1,410 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Tests for RegionPattern functionality in the autotuner. + +Tests pattern generation, matching, and tree visualization. +""" + +import os +import sys +import unittest + +import numpy as np +import onnx_graphsurgeon as gs + +# Add parent directory to path +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from modelopt.onnx.quantization.autotune.common import Region, RegionType +from modelopt.onnx.quantization.autotune.region_pattern import RegionPattern + + +class TestRegionPattern(unittest.TestCase): + """Test RegionPattern functionality.""" + + # ========================================================================= + # Helper Methods + # ========================================================================= + + @staticmethod + def _create_simple_graph(): + """Create a simple Conv->Relu graph for testing. + + Graph structure: + input -> Conv -> Relu -> output + """ + # Create inputs and outputs + inp = gs.Variable(name="input", dtype=np.float32, shape=[1, 3, 224, 224]) + conv_out = gs.Variable(name="conv_out", dtype=np.float32) + relu_out = gs.Variable(name="output", dtype=np.float32) + + # Create weights + conv_weight = gs.Constant( + name="conv_weight", values=np.ones((64, 3, 3, 3), dtype=np.float32) + ) + + # Create nodes + conv = gs.Node( + name="Conv_0", + op="Conv", + inputs=[inp, conv_weight], + outputs=[conv_out], + attrs={"kernel_shape": [3, 3], "strides": [1, 1], "pads": [1, 1, 1, 1]}, + ) + relu = gs.Node( + name="Relu_0", + op="Relu", + inputs=[conv_out], + outputs=[relu_out], + ) + + # Create graph + graph = gs.Graph( + nodes=[conv, relu], + inputs=[inp], + outputs=[relu_out], + opset=13, + ) + return graph + + @staticmethod + def _create_hierarchical_graph(): + """Create a hierarchical graph with composite regions. 
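+
+        A typical use in the tests below (a sketch; the helpers are defined
+        in this class):
+
+            leaf = self._create_test_region(1, 0, RegionType.LEAF, [0, 1])
+            pattern = RegionPattern.from_region(leaf, graph)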
+ + Graph structure: + input -> Conv -> Relu -> Add -> MatMul -> Relu -> output + ^ + | + other_input + + Region structure: + ROOT + ├── COMPOSITE (Conv->Relu->Add) + │ ├── LEAF (Conv->Relu) + │ └── LEAF (Add) + └── COMPOSITE (MatMul->Relu) + └── LEAF (MatMul->Relu) + """ + # Create inputs and intermediate tensors + inp = gs.Variable(name="input", dtype=np.float32, shape=[1, 64, 64, 64]) + other_inp = gs.Variable(name="other_input", dtype=np.float32, shape=[1, 64, 64, 64]) + conv_out = gs.Variable(name="conv_out", dtype=np.float32) + relu1_out = gs.Variable(name="relu1_out", dtype=np.float32) + add_out = gs.Variable(name="add_out", dtype=np.float32) + matmul_out = gs.Variable(name="matmul_out", dtype=np.float32) + output = gs.Variable(name="output", dtype=np.float32) + + # Create constants + conv_weight = gs.Constant( + name="conv_weight", values=np.ones((64, 64, 1, 1), dtype=np.float32) + ) + matmul_weight = gs.Constant( + name="matmul_weight", values=np.ones((64, 64), dtype=np.float32) + ) + + # Create nodes (order matters for node indices) + conv = gs.Node( + name="Conv_0", + op="Conv", + inputs=[inp, conv_weight], + outputs=[conv_out], + attrs={"kernel_shape": [1, 1]}, + ) # Node 0 + relu1 = gs.Node(name="Relu_0", op="Relu", inputs=[conv_out], outputs=[relu1_out]) # Node 1 + add = gs.Node( + name="Add_0", op="Add", inputs=[relu1_out, other_inp], outputs=[add_out] + ) # Node 2 + matmul = gs.Node( + name="MatMul_0", op="MatMul", inputs=[add_out, matmul_weight], outputs=[matmul_out] + ) # Node 3 + relu2 = gs.Node(name="Relu_1", op="Relu", inputs=[matmul_out], outputs=[output]) # Node 4 + + # Create graph + graph = gs.Graph( + nodes=[conv, relu1, add, matmul, relu2], + inputs=[inp, other_inp], + outputs=[output], + opset=13, + ) + return graph + + @staticmethod + def _create_test_region( + region_id: int, level: int, region_type: RegionType, node_indices: list[int] | None = None + ) -> Region: + """Create a test region.""" + region = Region(region_id, level, region_type) + if node_indices: + region.add_nodes(node_indices) + return region + + # ========================================================================= + # Test Cases + # ========================================================================= + + def test_pattern_creation(self): + """Test basic RegionPattern creation.""" + pattern = RegionPattern(signature="Conv->Relu", size=2) + + assert pattern.signature == "Conv->Relu" + assert pattern.size == 2 + assert not pattern.is_empty + assert pattern.is_leaf + assert not pattern.is_composite + + def test_pattern_equality(self): + """Test RegionPattern equality based on signature.""" + pattern1 = RegionPattern(signature="Conv->Relu", size=2) + pattern2 = RegionPattern(signature="Conv->Relu", size=5) # Different size + pattern3 = RegionPattern(signature="Gemm->Relu", size=2) + + # Same signature = equal (size doesn't affect equality) + assert pattern1 == pattern2 + # Different signature = not equal + assert pattern1 != pattern3 + + def test_pattern_hash(self): + """Test RegionPattern hashing for use in dicts/sets.""" + pattern1 = RegionPattern(signature="Conv->Relu", size=2) + pattern2 = RegionPattern(signature="Conv->Relu", size=5) + + # Same signature = same hash + assert hash(pattern1) == hash(pattern2) + + # Can be used as dict keys + pattern_dict = {pattern1: "scheme1"} + assert pattern_dict[pattern2] == "scheme1" # pattern2 finds pattern1's entry + + def test_pattern_from_simple_region(self): + """Test pattern computation from a simple region.""" + graph = 
self._create_simple_graph() + + # Create a leaf region with Conv and Relu nodes + region = self._create_test_region( + region_id=1, level=0, region_type=RegionType.LEAF, node_indices=[0, 1] + ) + + pattern = RegionPattern.from_region(region, graph) + + # Should capture both operations + assert "Conv" in pattern.signature + assert "Relu" in pattern.signature + assert pattern.size == 2 + assert pattern.is_leaf + + def test_pattern_from_composite_region(self): + """Test pattern computation from a composite region with children.""" + graph = self._create_hierarchical_graph() + + # Create leaf regions + leaf1 = self._create_test_region( + region_id=1, + level=0, + region_type=RegionType.LEAF, + node_indices=[0, 1], # Conv, Relu + ) + leaf2 = self._create_test_region( + region_id=2, + level=0, + region_type=RegionType.LEAF, + node_indices=[2], # Add + ) + + # Create composite region + composite = self._create_test_region( + region_id=3, level=1, region_type=RegionType.COMPOSITE, node_indices=[] + ) + composite.add_child(leaf1) + composite.add_child(leaf2) + + pattern = RegionPattern.from_region(composite, graph) + + assert pattern.is_composite + assert "COMPOSITE" in pattern.signature + assert pattern.size == 3 # Total nodes in region hierarchy + + def test_pattern_get_hash(self): + """Test cryptographic hash generation.""" + pattern = RegionPattern(signature="Conv->Relu", size=2) + hash_val = pattern.get_hash() + + # Hash should be 32 hex characters (128-bit truncated SHA-256) + assert len(hash_val) == 32 + assert all(c in "0123456789abcdef" for c in hash_val) + + # Same signature = same hash + pattern2 = RegionPattern(signature="Conv->Relu", size=5) + assert pattern.get_hash() == pattern2.get_hash() + + def test_pattern_get_short_signature(self): + """Test signature truncation.""" + long_sig = "COMPOSITE(" + "Conv->Relu->" * 20 + "Output)" + pattern = RegionPattern(signature=long_sig, size=20) + + short_sig = pattern.get_short_signature(max_length=50) + assert len(short_sig) == 50 + assert short_sig.endswith("...") + + # Short signature stays unchanged + short_pattern = RegionPattern(signature="Conv", size=1) + assert short_pattern.get_short_signature(max_length=50) == "Conv" + + def test_print_tree(self): + """Test format_tree to visualize region structure. + + This test demonstrates how to use format_tree to display + the hierarchical structure of regions and their patterns. 
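+
+        The assertions below rely only on stable properties of the output:
+        each region appears on a "Region <id>" line, and child regions are
+        indented relative to their parent.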
+ """ + graph = self._create_hierarchical_graph() + + # Build a hierarchical region structure: + # ROOT (level=2) + # ├── COMPOSITE (level=1) [Conv->Relu + Add] + # │ ├── LEAF (level=0) [Conv, Relu - nodes 0,1] + # │ └── LEAF (level=0) [Add - node 2] + # └── LEAF (level=0) [MatMul, Relu - nodes 3,4] + + # Create leaf regions + leaf_conv_relu = self._create_test_region( + region_id=1, level=0, region_type=RegionType.LEAF, node_indices=[0, 1] + ) + leaf_add = self._create_test_region( + region_id=2, level=0, region_type=RegionType.LEAF, node_indices=[2] + ) + leaf_matmul_relu = self._create_test_region( + region_id=3, level=0, region_type=RegionType.LEAF, node_indices=[3, 4] + ) + + # Create composite region containing conv_relu and add + composite = self._create_test_region( + region_id=4, level=1, region_type=RegionType.COMPOSITE, node_indices=[] + ) + composite.add_child(leaf_conv_relu) + composite.add_child(leaf_add) + + # Create root region containing everything + root = self._create_test_region( + region_id=5, level=2, region_type=RegionType.ROOT, node_indices=[] + ) + root.add_child(composite) + root.add_child(leaf_matmul_relu) + + # Generate pattern for root and print tree + root_pattern = RegionPattern.from_region(root, graph) + tree_output = root_pattern.format_tree(root, graph) + + print("\n" + "=" * 60) + print("Region Tree Structure:") + print("=" * 60) + print(tree_output) + print("=" * 60) + + # Verify tree output contains expected elements + assert "Region 5" in tree_output # Root + assert "Region 4" in tree_output # Composite + assert "Region 1" in tree_output # Leaf conv_relu + assert "Region 2" in tree_output # Leaf add + assert "Region 3" in tree_output # Leaf matmul_relu + + # Verify indentation shows hierarchy + lines = tree_output.strip().split("\n") + assert len(lines) >= 3 # At least root + children + + # Root should have no indentation + assert lines[0].startswith("Region 5") + + # Children should be indented + indented_lines = [line for line in lines if line.startswith(" ")] + assert len(indented_lines) > 0 + + def test_pattern_matches_pattern(self): + """Test pattern-to-pattern matching.""" + pattern1 = RegionPattern(signature="Conv->Relu", size=2) + pattern2 = RegionPattern(signature="Conv->Relu", size=5) + pattern3 = RegionPattern(signature="Gemm->Relu", size=2) + + assert pattern1.matches(pattern2) # Same signature + assert not pattern1.matches(pattern3) # Different signature + + def test_pattern_matches_region(self): + """Test pattern-to-region matching.""" + graph = self._create_simple_graph() + + # Create region + region = self._create_test_region( + region_id=1, level=0, region_type=RegionType.LEAF, node_indices=[0, 1] + ) + + # Create pattern from region + pattern = RegionPattern.from_region(region, graph) + + # Match should return node IDs + node_ids = pattern.matches(region, graph) + assert node_ids is not None + assert set(node_ids) == {0, 1} + + def test_empty_region_pattern(self): + """Test pattern for empty region.""" + graph = self._create_simple_graph() + + # Create empty region + empty_region = self._create_test_region( + region_id=1, level=0, region_type=RegionType.LEAF, node_indices=[] + ) + + pattern = RegionPattern.from_region(empty_region, graph) + + assert pattern.is_empty + assert pattern.signature == "EMPTY" + assert pattern.size == 0 + + def test_symmetric_operation_signature(self): + """Test that symmetric operations (Add, Mul) have consistent signatures.""" + # Create two graphs with Add inputs in different order + inp1 = 
gs.Variable(name="input1", dtype=np.float32, shape=[1, 64]) + inp2 = gs.Variable(name="input2", dtype=np.float32, shape=[1, 64]) + out = gs.Variable(name="output", dtype=np.float32) + + # Graph 1: Add(inp1, inp2) + add1 = gs.Node(name="Add_0", op="Add", inputs=[inp1, inp2], outputs=[out]) + graph1 = gs.Graph(nodes=[add1], inputs=[inp1, inp2], outputs=[out], opset=13) + + # Graph 2: Add(inp2, inp1) - reversed inputs + add2 = gs.Node(name="Add_0", op="Add", inputs=[inp2, inp1], outputs=[out]) + graph2 = gs.Graph(nodes=[add2], inputs=[inp1, inp2], outputs=[out], opset=13) + + # Create regions + region1 = self._create_test_region(1, 0, RegionType.LEAF, [0]) + region2 = self._create_test_region(1, 0, RegionType.LEAF, [0]) + + pattern1 = RegionPattern.from_region(region1, graph1) + pattern2 = RegionPattern.from_region(region2, graph2) + + # Patterns should be equal regardless of input order + assert pattern1 == pattern2 + + def test_pattern_repr_and_str(self): + """Test string representations.""" + pattern = RegionPattern(signature="Conv->Relu", size=2) + + # str() shows just signature + assert str(pattern) == "Conv->Relu" + + # repr() shows full info + assert "RegionPattern" in repr(pattern) + assert "Conv->Relu" in repr(pattern) + assert "size=2" in repr(pattern) diff --git a/tests/unit/onnx/quantization/autotune/test_region_search.py b/tests/unit/onnx/quantization/autotune/test_region_search.py new file mode 100644 index 000000000..d510cc277 --- /dev/null +++ b/tests/unit/onnx/quantization/autotune/test_region_search.py @@ -0,0 +1,420 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Tests for region search algorithms. + +Tests CombinedRegionSearch, RegionPartitioner, and TopDownRegionBuilder. +Note: Comprehensive integration tests with real ONNX graphs should be in separate integration test files. +""" + +import io +import os +import sys +import unittest + +# Add parent directory to path +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +import onnx +import onnx_graphsurgeon as gs +from onnx import helper + +from modelopt.onnx.quantization.autotune.common import Region, RegionType +from modelopt.onnx.quantization.autotune.region_search import ( + CombinedRegionSearch, + RegionPartitioner, + TopDownRegionBuilder, +) + + +def create_simple_linear_graph(): + """ + Create a simple linear graph: Input -> Conv -> Relu -> Output. + + This is the simplest possible graph for testing region discovery. 
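+
+    Typical driver in the tests below (names match this file):
+
+        graph = create_simple_linear_graph()
+        regions = CombinedRegionSearch(graph).search_regions()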
+ """ + # Input + input_tensor = helper.make_tensor_value_info("input", onnx.TensorProto.FLOAT, [1, 3, 224, 224]) + + # Output + output_tensor = helper.make_tensor_value_info( + "output", onnx.TensorProto.FLOAT, [1, 64, 224, 224] + ) + + # Conv node + conv_node = helper.make_node( + "Conv", inputs=["input", "conv_weight"], outputs=["conv_out"], name="conv" + ) + + # Relu node + relu_node = helper.make_node("Relu", inputs=["conv_out"], outputs=["output"], name="relu") + + # Create graph + graph = helper.make_graph( + [conv_node, relu_node], + "simple_linear", + [input_tensor], + [output_tensor], + initializer=[ + helper.make_tensor( + "conv_weight", onnx.TensorProto.FLOAT, [64, 3, 3, 3], [0.1] * (64 * 3 * 3 * 3) + ) + ], + ) + + # Create model + model = helper.make_model(graph, producer_name="test") + + # Convert to GraphSurgeon + gs_graph = gs.import_onnx(model) + return gs_graph + + +def create_divergent_graph(): + """ + Create a graph with divergence: Input -> Conv -> [Relu1, Relu2] -> Add -> Output. + + Tests divergence/convergence pattern detection. + """ + input_tensor = helper.make_tensor_value_info("input", onnx.TensorProto.FLOAT, [1, 3, 224, 224]) + output_tensor = helper.make_tensor_value_info( + "output", onnx.TensorProto.FLOAT, [1, 64, 224, 224] + ) + + conv_node = helper.make_node( + "Conv", inputs=["input", "conv_weight"], outputs=["conv_out"], name="conv" + ) + relu1_node = helper.make_node("Relu", inputs=["conv_out"], outputs=["relu1_out"], name="relu1") + relu2_node = helper.make_node("Relu", inputs=["conv_out"], outputs=["relu2_out"], name="relu2") + add_node = helper.make_node( + "Add", inputs=["relu1_out", "relu2_out"], outputs=["output"], name="add" + ) + + graph = helper.make_graph( + [conv_node, relu1_node, relu2_node, add_node], + "divergent", + [input_tensor], + [output_tensor], + initializer=[ + helper.make_tensor( + "conv_weight", onnx.TensorProto.FLOAT, [64, 3, 3, 3], [0.1] * (64 * 3 * 3 * 3) + ) + ], + ) + + model = helper.make_model(graph, producer_name="test") + gs_graph = gs.import_onnx(model) + return gs_graph + + +class TestRegionPartitioner(unittest.TestCase): + """Test RegionPartitioner basic functionality.""" + + def test_creation_linear_graph(self): + """Test creating RegionPartitioner with a simple linear graph.""" + graph = create_simple_linear_graph() + partitioner = RegionPartitioner(graph) + + assert partitioner is not None + assert partitioner.graph == graph + + def test_partition_linear_graph(self): + """Test partitioning a simple linear graph.""" + graph = create_simple_linear_graph() + partitioner = RegionPartitioner(graph) + + regions = partitioner.partition_graph() + + # Should create at least one region + assert len(regions) > 0 + + # Check that regions cover most nodes (ONNX GS may add Constant nodes that aren't partitioned) + total_nodes = sum(len(r.get_region_nodes_and_descendants()) for r in regions) + assert total_nodes > 0 + assert total_nodes <= len(graph.nodes) + + def test_partition_divergent_graph(self): + """Test partitioning a divergent graph.""" + graph = create_divergent_graph() + partitioner = RegionPartitioner(graph) + + regions = partitioner.partition_graph() + + # Should create regions covering all nodes + assert len(regions) > 0 + + # Check that regions cover most nodes (ONNX GS may add Constant nodes that aren't partitioned) + total_nodes = sum(len(r.get_region_nodes_and_descendants()) for r in regions) + assert total_nodes > 0 + assert total_nodes <= len(graph.nodes) + + +class 
TestTopDownRegionBuilder(unittest.TestCase): + """Test TopDownRegionBuilder basic functionality.""" + + def test_creation(self): + """Test creating TopDownRegionBuilder.""" + graph = create_simple_linear_graph() + + # Create a root region with all nodes + root = Region(region_id=0, level=0, region_type=RegionType.LEAF) + for idx in range(len(graph.nodes)): + root.add_node(idx) + + builder = TopDownRegionBuilder(graph, root) + + assert builder is not None + assert builder.graph == graph + + def test_build_composite_region(self): + """Test building a composite region.""" + graph = create_simple_linear_graph() + + # First partition to get initial regions + partitioner = RegionPartitioner(graph) + initial_regions = partitioner.partition_graph() + + if len(initial_regions) > 0: + # Use first region as root for top-down building + root_region = initial_regions[0] + + builder = TopDownRegionBuilder(graph, root_region, next_region_id=100) + + # Build composite region (may return LEAF or COMPOSITE depending on structure) + composite = builder.build_composite_region() + + assert composite is not None + # Region type depends on whether refinement created internal structure + # For simple linear graphs, may stay as LEAF + assert composite.type in [RegionType.LEAF, RegionType.COMPOSITE] + else: + self.skipTest("No initial regions to refine") + + +class TestCombinedRegionSearch(unittest.TestCase): + """Test CombinedRegionSearch two-phase algorithm.""" + + def test_creation(self): + """Test creating CombinedRegionSearch.""" + graph = create_simple_linear_graph() + search = CombinedRegionSearch(graph) + + assert search is not None + assert search.graph == graph + + def test_search_linear_graph(self): + """Test searching regions in a simple linear graph.""" + graph = create_simple_linear_graph() + search = CombinedRegionSearch(graph) + + regions = search.search_regions() + + # Should create regions + assert len(regions) > 0 + + # Check that regions cover most nodes (ONNX GS may add Constant nodes that aren't partitioned) + total_nodes = sum(len(r.get_region_nodes_and_descendants()) for r in regions) + assert total_nodes > 0 + assert total_nodes <= len(graph.nodes) + + # Each region should have valid inputs/outputs + for region in regions: + assert region.inputs is not None + assert region.outputs is not None + + def test_search_divergent_graph(self): + """Test searching regions in a divergent graph.""" + graph = create_divergent_graph() + search = CombinedRegionSearch(graph) + + regions = search.search_regions() + + # Should create regions + assert len(regions) > 0 + + # Check that regions cover most nodes (ONNX GS may add Constant nodes that aren't partitioned) + total_nodes = sum(len(r.get_region_nodes_and_descendants()) for r in regions) + assert total_nodes > 0 + assert total_nodes <= len(graph.nodes) + + def test_region_hierarchy(self): + """Test that regions have proper hierarchical structure.""" + graph = create_simple_linear_graph() + search = CombinedRegionSearch(graph) + + regions = search.search_regions() + + # Check that regions have children (hierarchical structure) + for region in regions: + if region.type == RegionType.COMPOSITE: + assert len(region.get_children()) > 0 + + # Verify parent-child relationships + for child in region.get_children(): + assert child.parent == region + + def test_parameters(self): + """Test CombinedRegionSearch with custom parameters.""" + graph = create_simple_linear_graph() + + # Test with different parameter values + search = CombinedRegionSearch( + graph, 
maximum_sequence_region_size=5, minimum_topdown_search_size=5 + ) + + regions = search.search_regions() + + assert len(regions) > 0 + + +class TestPrintTree(unittest.TestCase): + """Test print_tree functionality.""" + + def test_print_tree_basic(self): + """Test basic print_tree output.""" + graph = create_simple_linear_graph() + search = CombinedRegionSearch(graph) + search.search_regions() + + # Capture output to StringIO + output = io.StringIO() + search.print_tree(file=output) + + result = output.getvalue() + + # Should contain region information + assert "Region" in result + assert "Level" in result + assert "Type:" in result + + def test_print_tree_contains_node_info(self): + """Test that print_tree output contains node information.""" + graph = create_simple_linear_graph() + search = CombinedRegionSearch(graph) + search.search_regions() + + output = io.StringIO() + search.print_tree(file=output) + + result = output.getvalue() + + # Should contain node counts + assert "Direct nodes:" in result + assert "Total nodes (recursive):" in result + assert "Children:" in result + + def test_print_tree_contains_io_info(self): + """Test that print_tree output contains input/output tensor info.""" + graph = create_simple_linear_graph() + search = CombinedRegionSearch(graph) + search.search_regions() + + output = io.StringIO() + search.print_tree(file=output) + + result = output.getvalue() + + # Should contain I/O information + assert "Inputs:" in result + assert "Outputs:" in result + assert "tensors" in result + + def test_print_tree_divergent_graph(self): + """Test print_tree on a divergent graph with more complex structure.""" + graph = create_divergent_graph() + search = CombinedRegionSearch(graph) + search.search_regions() + + output = io.StringIO() + search.print_tree(file=output) + + result = output.getvalue() + + # Should produce valid output + assert "Region" in result + assert len(result) > 0 + + def test_print_tree_max_nodes_to_show(self): + """Test print_tree with custom max_nodes_to_show parameter.""" + graph = create_simple_linear_graph() + search = CombinedRegionSearch(graph) + search.search_regions() + + # Test with different max_nodes_to_show values + output1 = io.StringIO() + search.print_tree(max_nodes_to_show=1, file=output1) + + output2 = io.StringIO() + search.print_tree(max_nodes_to_show=10, file=output2) + + # Both should produce output + assert len(output1.getvalue()) > 0 + assert len(output2.getvalue()) > 0 + + def test_print_tree_specific_region(self): + """Test print_tree with a specific region instead of root.""" + graph = create_simple_linear_graph() + search = CombinedRegionSearch(graph) + regions = search.search_regions() + + if len(regions) > 0: + # Print a specific region + output = io.StringIO() + search.print_tree(region=regions[0], file=output) + + result = output.getvalue() + assert "Region" in result + assert f"Region {regions[0].id}" in result + + def test_print_tree_partitioner(self): + """Test print_tree on RegionPartitioner.""" + graph = create_simple_linear_graph() + partitioner = RegionPartitioner(graph) + partitioner.partition_graph() + + output = io.StringIO() + partitioner.print_tree(file=output) + + result = output.getvalue() + assert "Region" in result + assert len(result) > 0 + + def test_print_tree_top_down_builder(self): + """Test print_tree on TopDownRegionBuilder.""" + graph = create_simple_linear_graph() + + # Create a root region with all nodes + root = Region(region_id=0, level=0, region_type=RegionType.LEAF) + for idx in 
range(len(graph.nodes)): + root.add_node(idx) + + builder = TopDownRegionBuilder(graph, root) + # Compute region I/O boundaries before building + builder.compute_region_boundaries(root) + builder.build_composite_region() + + output = io.StringIO() + builder.print_tree(file=output) + + result = output.getvalue() + print("\n" + "=" * 60) + print("Region Tree Structure:") + print("=" * 60) + print(result) + print("=" * 60) + + assert "Region" in result + assert len(result) > 0 From 559d12ca9e086b7d66c63055b52542d4c27d6c60 Mon Sep 17 00:00:00 2001 From: Will Guo Date: Mon, 26 Jan 2026 05:54:47 +0000 Subject: [PATCH 2/5] Part-2 recent refactor changes Signed-off-by: Will Guo --- modelopt/onnx/op_types.py | 2 +- .../onnx/quantization/autotune/__init__.py | 88 +- modelopt/onnx/quantization/autotune/common.py | 746 +------ .../quantization/autotune/region_inspect.py | 10 +- .../quantization/autotune/region_pattern.py | 591 +----- .../quantization/autotune/region_search.py | 1861 ++--------------- modelopt/onnx/quantization/graph_utils.py | 3 +- modelopt/onnx/quantization/qdq_utils.py | 31 +- .../autotune/test_region_search.py | 7 +- 9 files changed, 393 insertions(+), 2946 deletions(-) diff --git a/modelopt/onnx/op_types.py b/modelopt/onnx/op_types.py index 0352e7106..cd32ba17a 100644 --- a/modelopt/onnx/op_types.py +++ b/modelopt/onnx/op_types.py @@ -96,7 +96,7 @@ def is_fusible_scaling_op(op_type: str): ] -def get_copy_ops(): +def get_copy_ops() -> list[str]: """Returns list of copy operators.""" return [ "Flatten", diff --git a/modelopt/onnx/quantization/autotune/__init__.py b/modelopt/onnx/quantization/autotune/__init__.py index a65b2ccba..3aa63c94c 100644 --- a/modelopt/onnx/quantization/autotune/__init__.py +++ b/modelopt/onnx/quantization/autotune/__init__.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -18,77 +18,6 @@ This package provides automated optimization of Quantize/Dequantize (Q/DQ) node placement in ONNX computation graphs to minimize TensorRT inference latency. It uses pattern-based region analysis to efficiently explore and optimize Q/DQ insertion strategies. 
- -**Key Features:** - -- **Automated Region Discovery**: Hierarchical decomposition of computation graphs into - LEAF and COMPOSITE regions with automatic pattern identification - -- **Pattern-Based Optimization**: Groups structurally-similar regions and optimizes them - together, making the process efficient and consistent - -- **TensorRT Performance Measurement**: Direct integration with TensorRT Python API for - accurate latency profiling of each Q/DQ configuration - -- **State Management**: Checkpoint/resume capability for long-running optimizations with - incremental state saving after each region - -- **Pattern Cache**: Warm-start optimization using learned schemes from previous runs, - enabling transfer learning across models - -**Core Components:** - -Autotuner Classes: - - QDQAutotuner: Main autotuner with automatic hierarchical region discovery - - QDQAutotunerBase: Base class for custom region identification strategies - -Region Management: - - Region: Hierarchical subgraph representation (nodes + children) - - RegionType: Enumeration (LEAF, COMPOSITE, ROOT) - - CombinedRegionSearch: Two-phase region discovery (partitioning + refinement) - - RegionPattern: Structural pattern analysis and matching for region grouping - -Q/DQ Insertion Points: - - InsertionScheme: Collection of Q/DQ insertion points for a region pattern - - NodeInputInsertionPoint: Q/DQ insertion at specific node inputs - - ChildRegionInputInsertionPoint: Q/DQ insertion at child region input boundaries - - RegionOutputInsertionPoint: Q/DQ insertion at region output boundaries - -Configuration & State: - - Config: Autotuning parameters (quant type, thresholds, verbosity) - - PatternCache: Top-performing schemes indexed by pattern (warm-start) - - PatternSchemes: Scheme collection and measurement results for a pattern - -Benchmarking: - - Benchmark: Abstract base class for model benchmarking - - TensorRTPyBenchmark: Benchmark using TensorRT Python API (recommended) - - TrtExecBenchmark: Benchmark using trtexec command-line tool (legacy) - -**Quick Start:** - - >>> from modelopt.onnx.quantization.autotune import QDQAutotuner, Config - >>> import onnx - >>> # Load model and initialize autotuner - >>> model = onnx.load("model.onnx") - >>> autotuner = QDQAutotuner(model) - >>> # Configure autotuning parameters - >>> config = Config(default_quant_type="int8") - >>> autotuner.initialize(config) - >>> # Generate and test Q/DQ schemes - >>> # (see workflows.region_pattern_autotuning_workflow for complete example) - -**Command-Line Interface:** - - The package can be run directly as a module: - - $ python -m modelopt.onnx.quantization.autotune --model model.onnx --output ./output - $ python -m modelopt.onnx.quantization.autotune --model model.onnx --quant-type fp8 - -**See Also:** - - - workflows.region_pattern_autotuning_workflow: Complete end-to-end optimization - - QDQAutotuner: Main autotuner class documentation - - RegionPattern: Pattern matching and signature computation """ # Core data structures @@ -101,44 +30,31 @@ PatternCache, PatternSchemes, Region, - RegionError, RegionType, ) - -# Insertion points (from dedicated module) from .insertion_points import ( ChildRegionInputInsertionPoint, NodeInputInsertionPoint, RegionOutputInsertionPoint, ResolvedInsertionPoint, ) - -# Pattern analysis from .region_pattern import RegionPattern - -# Region search from .region_search import CombinedRegionSearch -# Public API __all__ = [ - # Exceptions "AutotunerError", "AutotunerNotInitializedError", 
"ChildRegionInputInsertionPoint", "CombinedRegionSearch", - # Configuration and state "Config", - # Q/DQ insertion "InsertionScheme", "InvalidSchemeError", "NodeInputInsertionPoint", - "ResolvedInsertionPoint", "PatternCache", "PatternSchemes", - # Region classes "Region", - "RegionError", "RegionOutputInsertionPoint", "RegionPattern", "RegionType", + "ResolvedInsertionPoint", ] diff --git a/modelopt/onnx/quantization/autotune/common.py b/modelopt/onnx/quantization/autotune/common.py index 42b63c251..5418e2a58 100644 --- a/modelopt/onnx/quantization/autotune/common.py +++ b/modelopt/onnx/quantization/autotune/common.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -13,31 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Common data structures and types for the QDQ Autotuner. - -This module provides the foundational classes used throughout the autotuner: - -**Exceptions:** -- Region-related: RegionError -- Autotuner-related: AutotunerError, AutotunerNotInitializedError, InvalidSchemeError - -**Region Hierarchy:** -- Region: Hierarchical subgraph representation with parent/child relationships -- RegionType: Enumeration for LEAF, COMPOSITE, and ROOT regions - -**Q/DQ Insertion Specifications:** -- InsertionScheme: Collection of insertion points with performance metrics - -**Scheme Management:** -- PatternSchemes: Multiple insertion schemes for a pattern (applies to all matching regions) -- PatternCache: Collection of top schemes for multiple patterns, used as autotuning seeds - -**Configuration:** -- Config: Autotuning parameters and Q/DQ default values -""" +"""Common data structures and types for the QDQ Autotuner.""" import hashlib -import logging from dataclasses import dataclass, field from enum import Enum from typing import TYPE_CHECKING, Any, Optional @@ -45,6 +23,7 @@ import onnx_graphsurgeon as gs import yaml +from modelopt.onnx.logging_config import logger from modelopt.onnx.quantization.autotune.insertion_points import ( ChildRegionInputInsertionPoint, NodeInputInsertionPoint, @@ -55,16 +34,7 @@ if TYPE_CHECKING: from modelopt.onnx.quantization.autotune.region_pattern import RegionPattern -# Module logger -logger = logging.getLogger(__name__) - -# Region-related Exceptions -class RegionError(Exception): - """Base exception for region-related errors.""" - - -# Autotuner-related Exceptions class AutotunerError(Exception): """Base exception for autotuner-related errors.""" @@ -91,29 +61,12 @@ class RegionType(Enum): class Region: - """Hierarchical subgraph region in an ONNX computation graph. 
- - A Region represents a cohesive subgraph with well-defined boundaries, supporting: - - **Hierarchical Structure:** - - Parent/child relationships forming a multi-level hierarchy - - LEAF regions contain only direct nodes - - COMPOSITE regions contain child regions (and optionally direct nodes) - - ROOT regions encompass the entire graph - - **Node Management:** - - Direct nodes: Operations directly in this region (not in children) - - Recursive nodes: All operations including those in descendant regions - - **Boundary Tracking:** - - Input tensors: Data entering the region from outside - - Output tensors: Data leaving the region to outside consumers + """A subgraph region in an ONNX graph, used as the unit for Q/DQ insertion. - **Pattern Matching:** - - Regions with identical structure share the same pattern signature - - Pattern-based optimization applies schemes to all matching regions - - Regions are the fundamental unit for Q/DQ insertion and optimization. + Regions form a hierarchy: ROOT contains the entire graph, COMPOSITE regions + contain child regions, and LEAF regions contain only nodes. Each region tracks + its direct nodes, input/output tensors, and a pattern signature for matching + regions with identical structure. """ def __init__(self, region_id: int, level: int, region_type: RegionType): @@ -132,119 +85,64 @@ def __init__(self, region_id: int, level: int, region_type: RegionType): self.nodes: set[int] = set() self.inputs: list[str] = [] self.outputs: list[str] = [] + self.metadata: dict[str, str] = {} - # ========================================================================= - # Basic Accessors - # ========================================================================= - - def get_id(self) -> int: - """Get region ID.""" - return self.id - - def set_id(self, region_id: int) -> None: - """Set region ID (for RegionBuilder use).""" - self.id = region_id - - def get_level(self) -> int: - """Get region level in hierarchy.""" - return self.level - - def set_level(self, level: int) -> None: - """Set region level in hierarchy (for RegionBuilder use).""" - self.level = level - - def get_type(self) -> RegionType: - """Get region type.""" - return self.type - - def set_type(self, region_type: RegionType) -> None: - """Set region type (for RegionBuilder use).""" - self.type = region_type - - # ========================================================================= - # Hierarchy Management - # ========================================================================= - - def get_parent(self) -> Optional["Region"]: - """Get parent region.""" - return self.parent - - def set_parent(self, parent: Optional["Region"]) -> None: - """Set parent region.""" - self.parent = parent - - def get_children(self) -> list["Region"]: + def get_children(self, *, sort: bool = False) -> list["Region"]: """Get all child regions.""" + if sort: + return sorted( + self.children, key=lambda r: (-r.level, r.get_size_of_region_and_descendants()) + ) return self.children def remove_child(self, child: "Region") -> bool: - """Remove a child region from this region's children list. 
- - Args: - child: The child region to remove - - Returns: - True if child was found and removed, False otherwise - """ - child_id = child.get_id() - initial_count = len(self.children) - self.children = [c for c in self.children if c.get_id() != child_id] - removed = len(self.children) < initial_count - - if removed and child.parent and child.parent.get_id() == self.id: - child.set_parent(None) - - return removed + """Remove a child region from this region's children list.""" + if child not in self.children: + return False + self.children.remove(child) + if child.parent and child.parent.id == self.id: + child.parent = None + return True def add_child(self, child: "Region") -> None: """Add a child sub-region.""" - # Prevent adding self as child - if child.get_id() == self.id: + if child.id == self.id: logger.warning(f"Cannot add region {self.id} as its own child") return - # Prevent creating cycles: check if self is already a descendant of child - if self._is_descendant_of(child): + if self.is_descendant_of(child): logger.warning( - f"Cycle detected: region {self.id} is already a descendant of region {child.get_id()}" + f"Cycle detected: region {self.id} is already a descendant of region {child.id}" ) return - # Check if child already has a different parent - if child.parent is not None and child.parent.get_id() != self.id: - old_parent_id = child.parent.get_id() + if child.parent is not None and child.parent.id != self.id: + old_parent_id = child.parent.id logger.debug( - f"Re-parenting region {child.get_id()}: moving from parent {old_parent_id} to {self.id}" + f"Re-parenting region {child.id}: moving from parent {old_parent_id} to {self.id}" ) - # Remove from old parent to maintain tree structure child.parent.remove_child(child) - # Check if child is already in children list - if any(c.get_id() == child.get_id() for c in self.children): - logger.debug(f"Region {child.get_id()} already child of {self.id}") + if any(c.id == child.id for c in self.children): + logger.debug(f"Region {child.id} already child of {self.id}") return self.children.append(child) - child.set_parent(self) + child.parent = self - def _is_descendant_of(self, potential_ancestor: "Region") -> bool: + def is_descendant_of(self, potential_ancestor: "Region") -> bool: """Check if this region is a descendant of potential_ancestor.""" visited = set() current = self.parent while current: - if current.get_id() in visited: - # Already visited, there's a cycle in parents + if current.id in visited: return False - visited.add(current.get_id()) - if current.get_id() == potential_ancestor.get_id(): + visited.add(current.id) + if current.id == potential_ancestor.id: return True current = current.parent return False - # ========================================================================= - # Node Management - # ========================================================================= - def add_node(self, node_index: int) -> None: """Add a node index to this region.""" self.nodes.add(node_index) @@ -253,65 +151,33 @@ def add_nodes(self, node_indices: list[int]) -> None: """Add multiple node indices to this region.""" self.nodes.update(node_indices) - def get_nodes(self) -> set[int]: - """Get direct node indices in this region only. - - Returns only nodes directly owned by this region, excluding nodes - in child regions. Use get_all_nodes_recursive() for complete coverage. 
- - Returns: - Set of node indices (absolute positions in the graph) - """ - return self.nodes - - def get_all_nodes_recursive(self, _visited: set[int] | None = None) -> set[int]: - """Get all node indices recursively, including descendants. - - Traverses the entire subtree rooted at this region, collecting nodes - from this region and all child regions recursively. + def get_nodes(self, *, sort: bool = False) -> list[int]: + """Get direct node indices in this region only.""" + if sort: + return sorted(self.nodes) + return list(self.nodes) - Args: - _visited: Internal parameter for cycle detection (do not use) - - Returns: - Set of all node indices in this region and its descendants - """ + def get_region_nodes_and_descendants(self, _visited: set[int] | None = None) -> set[int]: + """Get all node indices recursively, including descendants.""" if _visited is None: _visited = set() # Detect cycles - if self.id in _visited: - logger.warning(f"Cycle detected in region {self.id} during node traversal") - return set() + assert self.id not in _visited, f"Cycle detected in region {self.id} during node traversal" _visited.add(self.id) all_nodes = set(self.nodes) for child in self.children: - all_nodes.update(child.get_all_nodes_recursive(_visited)) + all_nodes.update(child.get_region_nodes_and_descendants(_visited)) return all_nodes def contains_node(self, node_index: int) -> bool: """Check if region contains a specific node (direct only).""" return node_index in self.nodes - def contains_node_recursive(self, node_index: int, _visited: set[int] | None = None) -> bool: + def contains_node_within_region_and_descendants(self, node_index: int) -> bool: """Check if region contains a node recursively.""" - if _visited is None: - _visited = set() - - # Detect cycles - if self.id in _visited: - return False - - _visited.add(self.id) - - if self.contains_node(node_index): - return True - return any(child.contains_node_recursive(node_index, _visited) for child in self.children) - - # ========================================================================= - # Input/Output Management - # ========================================================================= + return node_index in self.get_region_nodes_and_descendants() def add_input(self, tensor_name: str) -> None: """Add an input tensor name.""" @@ -323,80 +189,31 @@ def add_output(self, tensor_name: str) -> None: if tensor_name not in self.outputs: self.outputs.append(tensor_name) - def get_inputs(self) -> list[str]: - """Get region input tensors.""" - return self.inputs - - def get_outputs(self) -> list[str]: - """Get region output tensors.""" - return self.outputs - - # ========================================================================= - # Size and Query Methods - # ========================================================================= - - def get_size(self) -> int: - """Get the number of direct nodes in this region. - - Returns: - Count of nodes directly in this region (excludes child regions) - """ - return len(self.nodes) - - def get_total_size(self, _visited: set[int] | None = None) -> int: - """Get total node count recursively including all descendants. - - Computes the sum of nodes in this region and all child regions, - providing the total footprint of the region subtree. 
- - Args: - _visited: Internal parameter for cycle detection (do not use) - - Returns: - Total number of nodes in this region and all descendants - """ + def get_size_of_region_and_descendants(self, _visited: set[int] | None = None) -> int: + """Get total node count recursively including all descendants.""" if _visited is None: _visited = set() # Detect cycles - if self.id in _visited: - logger.warning(f"Cycle detected in region {self.id} during size calculation") - return len(self.nodes) + assert self.id not in _visited, ( + f"Cycle detected in region {self.id} during size calculation" + ) _visited.add(self.id) total = len(self.nodes) for child in self.children: - total += child.get_total_size(_visited) + total += child.get_size_of_region_and_descendants(_visited) return total - # ========================================================================= - # Region Operations - # ========================================================================= - def merge(self, other: "Region") -> None: - """Merge another region into this one. - - Combines the nodes and children from the other region into this region. - The other region's children become children of this region, updating - their parent references accordingly. - - Args: - other: Region to merge into this one - """ + """Merge another region into this one.""" if not other: return - # Merge direct nodes self.nodes.update(other.nodes) - # Merge children (updates their parent references) for child in other.children: self.add_child(child) - # ========================================================================= - # String Representation - # ========================================================================= - - def to_string(self) -> str: - """Print region information for debugging.""" + def __repr__(self) -> str: type_str = self.type.value return ( f"Region[id={self.id}, level={self.level}, type={type_str}, " @@ -404,12 +221,6 @@ def to_string(self) -> str: f"inputs={len(self.inputs)}, outputs={len(self.outputs)}]" ) - def __str__(self) -> str: - return self.to_string() - - def __repr__(self) -> str: - return self.to_string() - def compute_structural_signature(self, graph: gs.Graph) -> str: """Compute deterministic structural signature for pattern matching. @@ -428,48 +239,12 @@ def compute_structural_signature(self, graph: gs.Graph) -> str: Returns: Signature string (e.g., "Conv->BatchNorm->Relu" or "COMPOSITE(...)") """ - # Import here to avoid circular dependency at runtime - from modelopt.onnx.quantization.autotune.region_pattern import RegionPattern - - return RegionPattern.from_region(self, graph).signature - - -# ============================================================================= -# Autotuner Q/DQ Insertion Specifications -# ============================================================================= + raise NotImplementedError("Not implemented") @dataclass class InsertionScheme: - """Complete Q/DQ insertion specification for a region pattern. - - An InsertionScheme defines a complete Q/DQ configuration for a pattern, - combining both node-level and region-level insertion points. The scheme - is applied to all regions matching the pattern. 
- - **Scheme Identity:** - - Uniquely identified by the combination of insertion points (computed hash) - - latency_ms is a measured performance metric, not part of identity - - Two schemes with same insertion points but different latencies are considered identical - - **Application:** - - Node insertion points: Q/DQ at node inputs within the pattern - - Region insertion points: Q/DQ at child region boundaries (COMPOSITE only) - - All are resolved to actual configurations for each matching region - - **Performance Tracking:** - - latency_ms: Measured performance (inf = not yet measured) - - error: Whether this scheme encountered an error during measurement - - Used to select the best scheme for each pattern - - **Attributes:** - node_inputs: Q/DQ insertions at node inputs (list of NodeInputInsertionPoint) - child_region_inputs: Q/DQ insertions at child boundaries (list of ChildRegionInputInsertionPoint) - region_outputs: Q/DQ insertions at region outputs (list of RegionOutputInsertionPoint) - latency_ms: Measured latency in milliseconds (inf if not measured) - error: True if scheme measurement failed, False otherwise - profile_timestamp: ISO format timestamp when this scheme was profiled (None if not yet profiled) - """ + """Q/DQ insertion specification applied to all regions matching a pattern.""" node_inputs: list[NodeInputInsertionPoint] = field(default_factory=list) child_region_inputs: list[ChildRegionInputInsertionPoint] = field(default_factory=list) @@ -480,27 +255,7 @@ class InsertionScheme: @property def hash(self) -> str: - """Compute deterministic hash for scheme identity. - - The hash uniquely identifies this scheme configuration based on its - insertion points. Two schemes with identical insertion points produce - the same hash, regardless of their measured latencies. - - **Hash Input:** - - Sorted node_inputs (for deterministic ordering) - - Sorted child_region_inputs (for deterministic ordering) - - Sorted region_outputs (for deterministic ordering) - - latency_ms is EXCLUDED (performance metric, not identity) - - **Use Cases:** - - Detect duplicate schemes before measurement - - Group schemes by configuration - - Efficient scheme comparison - - Returns: - 32-character hexadecimal string (SHA-256 truncated to 128 bits) - """ - # Sort points for deterministic hashing + """Compute deterministic hash for scheme identity.""" sorted_nodes = sorted([(pt.node_index, pt.input_index) for pt in self.node_inputs]) sorted_regions = sorted( [(pt.region_index, pt.input_index) for pt in self.child_region_inputs] @@ -509,79 +264,20 @@ def hash(self) -> str: [(pt.region_index, pt.node_index, pt.output_index) for pt in self.region_outputs] ) - # Create hash input string hash_input = f"{sorted_nodes}|{sorted_regions}|{sorted_region_outputs}" - # Compute SHA-256 hash (128 bits) return hashlib.sha256(hash_input.encode("utf-8")).hexdigest()[:32] @property def is_empty(self) -> bool: - """Check if this is a baseline scheme with no Q/DQ insertions. - - Returns: - True if scheme has no node/region insertion points - """ - return ( - len(self.node_inputs) == 0 - and len(self.child_region_inputs) == 0 - and len(self.region_outputs) == 0 - ) - - @property - def has_error(self) -> bool: - """Check if this scheme encountered an error during measurement. 
- - Returns: - True if scheme has error=True, False otherwise - """ - return self.error + """Check if this is a baseline scheme with no Q/DQ insertions.""" + return not self.node_inputs and not self.child_region_inputs and not self.region_outputs @property def is_profiled(self) -> bool: - """Check if this scheme has been profiled (measured). - - A scheme is considered profiled if it has been measured (has non-infinite latency) - or has encountered an error during measurement. - - Returns: - True if scheme has been measured (latency_ms != inf) or has error, - False if scheme is waiting to be profiled (error=False and latency_ms=inf) - """ + """Check if this scheme has been profiled (measured).""" return self.error or self.latency_ms != float("inf") - @property - def num_node_insertions(self) -> int: - """Get count of node-level Q/DQ insertion points. - - Returns: - Number of NodeInputInsertionPoint entries - """ - return len(self.node_inputs) - - @property - def num_region_insertions(self) -> int: - """Get count of region-level Q/DQ insertion points. - - These specify Q/DQ insertions at child region boundaries within - COMPOSITE regions. - - Returns: - Number of ChildRegionInputInsertionPoint entries - """ - return len(self.child_region_inputs) - - @property - def num_region_output_insertions(self) -> int: - """Get count of region output insertion points. - - These specify Q/DQ insertions at outputs from child regions or nodes. - - Returns: - Number of RegionOutputInsertionPoint entries - """ - return len(self.region_outputs) - def to_dict(self) -> dict[str, Any]: """Convert to dictionary for serialization.""" return { @@ -596,19 +292,7 @@ def to_dict(self) -> dict[str, Any]: @classmethod def from_dict(cls, data: dict[str, Any]) -> "InsertionScheme": - """Create InsertionScheme from serialized dictionary. - - Reconstructs the insertion scheme from saved data, including node and - region insertion points. The hash is automatically recomputed from all - components to ensure consistency. - - Args: - data: Dictionary containing 'latency_ms', 'nodes_insertion_points', - 'child_region_inputs', and 'region_outputs' keys - - Returns: - Reconstructed InsertionScheme instance - """ + """Create InsertionScheme from serialized dictionary.""" scheme = cls() scheme.latency_ms = data.get("latency_ms", float("inf")) scheme.error = data.get("error", False) @@ -625,73 +309,23 @@ def from_dict(cls, data: dict[str, Any]) -> "InsertionScheme": RegionOutputInsertionPoint.from_dict(pt) for pt in data.get("region_outputs", []) ] - # Note: hash is computed from points, so we don't load it from dict - # This ensures consistency even if stored hash differs - return scheme def distance(self, other: "InsertionScheme") -> int: - """Compute edit distance between this scheme and another scheme. - - The edit distance is the minimum number of add/remove operations needed - to transform this scheme into the other scheme. This is computed as the - symmetric difference between the insertion point sets. - - **Distance Calculation:** - - Counts insertion points in self but not in other (need to be removed) - - Counts insertion points in other but not in self (need to be added) - - Considers all three types of insertion points: - * node_inputs - * child_region_inputs - * region_outputs - - Args: - other: InsertionScheme to compare against - - Returns: - Total edit distance (number of add + remove operations) - - Example: - >>> scheme1 = InsertionScheme( - ... node_inputs=[ - ... NodeInputInsertionPoint(0, 0), - ... 
NodeInputInsertionPoint(1, 0), - ... ] - ... ) - >>> scheme2 = InsertionScheme( - ... node_inputs=[ - ... NodeInputInsertionPoint(0, 0), - ... NodeInputInsertionPoint(2, 0), - ... ] - ... ) - >>> scheme1.distance(scheme2) # 2 (remove (1,0), add (2,0)) - 2 - """ - # Convert insertion points to sets for efficient set operations - self_nodes = set(self.node_inputs) - other_nodes = set(other.node_inputs) - - self_regions = set(self.child_region_inputs) - other_regions = set(other.child_region_inputs) - - self_region_outputs = set(self.region_outputs) - other_region_outputs = set(other.region_outputs) - - # Compute symmetric difference (elements in either set but not both) - # This gives us the total number of add + remove operations - node_distance = len(self_nodes.symmetric_difference(other_nodes)) - region_distance = len(self_regions.symmetric_difference(other_regions)) - region_output_distance = len(self_region_outputs.symmetric_difference(other_region_outputs)) - - return node_distance + region_distance + region_output_distance + """Compute edit distance between this scheme and another scheme.""" + return ( + len(set(self.node_inputs).symmetric_difference(other.node_inputs)) + + len(set(self.child_region_inputs).symmetric_difference(other.child_region_inputs)) + + len(set(self.region_outputs).symmetric_difference(other.region_outputs)) + ) def __str__(self) -> str: """String representation for debugging.""" error_str = ", error=True" if self.error else "" return ( - f"InsertionScheme(node_insertions={self.num_node_insertions}, " - f"region_insertions={self.num_region_insertions}, " - f"region_output_insertions={self.num_region_output_insertions}, " + f"InsertionScheme(node_insertions={len(self.node_inputs)}, " + f"region_insertions={len(self.child_region_inputs)}, " + f"region_output_insertions={len(self.region_outputs)}, " f"latency={self.latency_ms:.3f}ms{error_str})" ) @@ -705,26 +339,13 @@ class PatternSchemes: enables pattern-based optimization where all regions with the same structure use the same Q/DQ insertion strategy. - **Workflow:** - 1. Pattern is identified from region structure - 2. Multiple schemes are generated and tested - 3. Each scheme is measured (latency_ms) - 4. Best scheme is selected (lowest latency) - 5. 
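# A small sketch of scheme identity vs. edit distance, assuming InsertionScheme
# and NodeInputInsertionPoint(node_index, input_index) as defined in this patch.
from modelopt.onnx.quantization.autotune.common import InsertionScheme
from modelopt.onnx.quantization.autotune.insertion_points import NodeInputInsertionPoint

s1 = InsertionScheme(node_inputs=[NodeInputInsertionPoint(0, 0), NodeInputInsertionPoint(1, 0)])
s2 = InsertionScheme(node_inputs=[NodeInputInsertionPoint(0, 0), NodeInputInsertionPoint(2, 0)])

# Symmetric difference across all three insertion-point sets:
# remove (1, 0), add (2, 0) -> distance 2.
assert s1.distance(s2) == 2

# latency_ms is excluded from the hash, so measuring a scheme never changes its identity.
s3 = InsertionScheme(node_inputs=list(s1.node_inputs), latency_ms=1.5)
assert s1.hash == s3.hash and s1.hash != s2.hash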
Best scheme is applied to all matching regions - - **Best Scheme Selection:** - - Automatically identifies scheme with lowest latency - - Excludes schemes with errors (error=True) - - Schemes with latency_ms = inf are considered unmeasured - - best_scheme property provides easy access to optimal configuration - **Attributes:** pattern: RegionPattern defining the structural signature schemes: List of InsertionScheme candidates with measurements """ - pattern: Optional["RegionPattern"] = None # Structural pattern signature - schemes: list[InsertionScheme] = field(default_factory=list) # Candidate schemes + pattern: Optional["RegionPattern"] = None + schemes: list[InsertionScheme] = field(default_factory=list) @property def pattern_signature(self) -> str: @@ -751,7 +372,7 @@ def best_scheme_index(self) -> int: return -1 min_idx, min_latency = -1, float("inf") for idx, scheme in enumerate(self.schemes): - if not scheme.has_error and scheme.latency_ms < min_latency: + if not scheme.error and scheme.latency_ms < min_latency: min_idx = idx min_latency = scheme.latency_ms return min_idx @@ -804,14 +425,10 @@ def get_valid_schemes(self) -> list[InsertionScheme]: Returns: List of schemes that completed successfully without errors """ - return [s for s in self.schemes if not s.has_error] + return [s for s in self.schemes if not s.error] def to_dict(self) -> dict[str, Any]: - """Convert to dictionary for serialization. - - Note: Excludes runtime objects like pattern (RegionPattern). - Only serializes metadata and schemes. - """ + """Convert to dictionary for serialization.""" return { "pattern_signature": self.pattern_signature, "pattern_size": self.pattern_size, @@ -840,12 +457,10 @@ def from_dict( Returns: Reconstructed PatternSchemes instance """ - # Import here to avoid circular dependency at runtime from modelopt.onnx.quantization.autotune.region_pattern import RegionPattern ps = cls() - # If no pattern provided, create minimal one from saved data if pattern is None and "pattern_signature" in data: pattern = RegionPattern( signature=data["pattern_signature"], size=data.get("pattern_size", 0) @@ -880,22 +495,6 @@ class PatternCache: - Similarity (similar schemes where only better-performing one is kept) - Count limit (only top N best schemes are kept per pattern) - **Seeded Autotuning:** - - Use previous autotuning results as starting points - - Skip redundant measurements for known patterns - - Transfer learned schemes across models or runs - - **Use Cases:** - - Load pattern cache from previous run to warm-start autotuning - - Share pattern cache data across similar models - - Store best-known schemes for common patterns - - **Workflow:** - 1. After autotuning, add schemes to PatternCache (non-performant entries auto-evicted) - 2. Serialize PatternCache to file (YAML) - 3. Load PatternCache in future runs as seeds - 4. Autotuner uses seeds to initialize pattern schemes - **Attributes:** pattern_schemes: List of PatternSchemes, one per pattern minimum_distance: Minimum edit distance required between schemes in cache. @@ -904,19 +503,9 @@ class PatternCache: max_entries_per_pattern: Maximum number of schemes to keep per pattern. Only the top N best-performing schemes are kept for each pattern. Use 0 to keep all schemes (default: 32) - - Example: - >>> # Save pattern cache after autotuning - >>> cache = PatternCache(minimum_distance=4, max_entries_per_pattern=32) - >>> for schemes in autotuner.pattern_schemes.values(): - ... 
cache.add_pattern_schemes(schemes)  # Auto-eviction happens here
-        >>> cache.save("pattern_cache.yaml")
-        >>>
-        >>> # Load pattern cache for next run
-        >>> cache = PatternCache.load("pattern_cache.yaml")
-        >>> autotuner.initialize(config, pattern_cache=cache)
     """
 
+    # List of PatternSchemes in the cache.
     pattern_schemes: list[PatternSchemes] = field(default_factory=list)
     # Minimum distance between schemes in cache.
     minimum_distance: int = 4
@@ -929,19 +518,6 @@ def add_pattern_schemes(self, pattern_schemes: PatternSchemes) -> None:
         Merges new schemes with existing schemes for the same pattern,
         automatically evicting schemes that are non-performant based on multiple criteria.
 
-        **Automatic Eviction Strategy:**
-
-        1. **Error Eviction**: Schemes with errors are automatically excluded
-
-        2. **Duplicate Eviction**: When schemes have identical configurations (same hash),
-           only the one with better latency is kept
-
-        3. **Similarity Eviction**: When minimum_distance > 0, schemes that are too similar
-           to better-performing schemes are evicted
-
-        4. **Count Eviction**: When max_entries_per_pattern > 0, only the top N
-           best-performing schemes are kept per pattern
-
         Args:
             pattern_schemes: PatternSchemes to add to the cache
         """
@@ -963,7 +539,7 @@ def add_pattern_schemes(self, pattern_schemes: PatternSchemes) -> None:
             all_schemes.extend(self.pattern_schemes[existing_idx].schemes)
 
         # Filter out schemes with errors and deduplicate by hash
-        valid_schemes = [s for s in all_schemes if not s.has_error]
+        valid_schemes = [s for s in all_schemes if not s.error]
         unique_schemes = {}
         for scheme in valid_schemes:
             scheme_hash = scheme.hash
@@ -982,21 +558,22 @@ def add_pattern_schemes(self, pattern_schemes: PatternSchemes) -> None:
             for scheme in sorted_schemes:
                 # Check if this scheme is too similar to any already-filtered scheme
                 too_similar = False
+                schemes_to_replace = []
                 for existing_scheme in filtered_schemes:
                     distance = scheme.distance(existing_scheme)
                     if distance < self.minimum_distance:
                         # Schemes are too similar, keep the better one
+                        too_similar = True
                         if scheme.latency_ms < existing_scheme.latency_ms:
-                            # New scheme is better, remove existing and add new
-                            filtered_schemes.remove(existing_scheme)
-                            break
-                        else:
-                            # Existing scheme is better, skip new one
-                            too_similar = True
-                            break
+                            # New scheme is better, mark existing for replacement
+                            schemes_to_replace.append(existing_scheme)
 
                 if not too_similar:
                     filtered_schemes.append(scheme)
+                elif schemes_to_replace:
+                    for replaced_scheme in schemes_to_replace:
+                        filtered_schemes.remove(replaced_scheme)
+                    filtered_schemes.append(scheme)
 
             sorted_schemes = filtered_schemes
 
@@ -1049,33 +626,15 @@ def add_pattern_from_region(
         insertion scheme. This allows capturing known-good quantization strategies
         from existing models and using them as seeds for autotuning.
 
-        **Workflow:**
-        1. Create RegionPattern from the region structure
-        2. Identify which tensors in the region are quantized
-        3. Map quantized tensors to pattern-relative insertion points:
-           - Node input tensors → NodeInputInsertionPoint
-           - Child region input tensors → ChildRegionInputInsertionPoint
-           - Region output tensors → RegionOutputInsertionPoint
-        4. Create InsertionScheme with identified insertion points
-        5. Add to pattern cache (or merge if pattern already exists)
-
         Args:
             region: Region from the quantized model to analyze
             graph: ONNX graph containing the region
            quantized_tensors: Set of tensor names that have Q/DQ nodes
 
-        Example:
-            >>> cache = PatternCache()
-            >>> for region in all_regions:
-            ...
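# A sketch of the deduplication behavior above, assuming the classes defined in
# this patch; the pattern signature and latencies are illustrative.
from modelopt.onnx.quantization.autotune.common import (
    InsertionScheme,
    PatternCache,
    PatternSchemes,
)
from modelopt.onnx.quantization.autotune.insertion_points import NodeInputInsertionPoint
from modelopt.onnx.quantization.autotune.region_pattern import RegionPattern

pattern = RegionPattern("Conv[kernel_shape=3x3]->Relu", size=2)
fast = InsertionScheme(node_inputs=[NodeInputInsertionPoint(0, 0)], latency_ms=1.0)
slow_duplicate = InsertionScheme(node_inputs=[NodeInputInsertionPoint(0, 0)], latency_ms=2.0)

cache = PatternCache(minimum_distance=0, max_entries_per_pattern=8)
cache.add_pattern_schemes(PatternSchemes(pattern=pattern, schemes=[fast, slow_duplicate]))

# Both schemes share a hash, so only the lower-latency one survives deduplication.
assert len(cache.pattern_schemes[0].schemes) == 1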
cache.add_pattern_from_region(region, graph, quantized_tensors) - >>> cache.save("learned_patterns.yaml") """ - # Import here to avoid circular dependency at runtime from modelopt.onnx.quantization.autotune.region_pattern import RegionPattern - # Create pattern from region pattern = RegionPattern.from_region(region, graph) - # Track insertion points scheme = InsertionScheme( node_inputs=[], child_region_inputs=[], @@ -1083,7 +642,6 @@ def add_pattern_from_region( latency_ms=float("inf"), error=False, ) - # Analyze node inputs full_insertion_scheme = pattern.get_full_insertion_scheme(region, graph) for point in full_insertion_scheme.node_inputs: temp_scheme = InsertionScheme( @@ -1097,7 +655,6 @@ def add_pattern_from_region( temp_tensor_names = {tensor.tensor_name for tensor in temp_ips} if len(temp_tensor_names.intersection(quantized_tensors)) > 0: scheme.node_inputs.append(point) - # Analyze region boundaries (for COMPOSITE regions) if region.type == RegionType.COMPOSITE: for child_point in full_insertion_scheme.child_region_inputs: temp_scheme = InsertionScheme( @@ -1111,7 +668,6 @@ def add_pattern_from_region( temp_tensor_names = {tensor.tensor_name for tensor in temp_ips} if len(temp_tensor_names.intersection(quantized_tensors)) > 0: scheme.child_region_inputs.append(child_point) - # Analyze region outputs for output_point in full_insertion_scheme.region_outputs: temp_scheme = InsertionScheme( node_inputs=[], @@ -1124,16 +680,12 @@ def add_pattern_from_region( temp_tensor_names = {tensor.tensor_name for tensor in temp_ips} if len(temp_tensor_names.intersection(quantized_tensors)) > 0: scheme.region_outputs.append(output_point) - # Add pattern and scheme to pattern cache pattern_schemes = PatternSchemes(pattern=pattern, schemes=[scheme]) self.add_pattern_schemes(pattern_schemes) num_points = ( len(scheme.node_inputs) + len(scheme.child_region_inputs) + len(scheme.region_outputs) ) - logger.debug( - f"Added pattern from region {region.get_id()} with {num_points} insertion points" - ) - # Add patterns from child regions + logger.debug(f"Added pattern from region {region.id} with {num_points} insertion points") if region.type == RegionType.COMPOSITE: for child_region in region.get_children(): self.add_pattern_from_region(child_region, graph, quantized_tensors) @@ -1149,11 +701,7 @@ def total_schemes(self) -> int: return sum(ps.num_schemes for ps in self.pattern_schemes) def get_all_pattern_signatures(self) -> list[str]: - """Get list of all pattern signatures in pattern cache. - - Returns: - List of pattern signature strings - """ + """Get list of all pattern signatures in pattern cache.""" return [ps.pattern_signature for ps in self.pattern_schemes] def clear(self) -> None: @@ -1161,23 +709,13 @@ def clear(self) -> None: self.pattern_schemes.clear() def merge(self, other: "PatternCache", prefer_existing: bool = True) -> None: - """Merge another PatternCache into this one. - - Args: - other: PatternCache to merge - prefer_existing: If True, keep existing patterns when there's a conflict. - If False, overwrite with other's patterns. - """ + """Merge another PatternCache into this one.""" for schemes in other.pattern_schemes: if not self.has_pattern(schemes.pattern_signature) or not prefer_existing: self.add_pattern_schemes(schemes) def to_dict(self) -> dict[str, Any]: - """Convert to dictionary for serialization. 
- - Returns: - Dictionary with 'minimum_distance', 'max_entries_per_pattern', and 'pattern_schemes' keys - """ + """Convert to dictionary for serialization.""" return { "minimum_distance": self.minimum_distance, "max_entries_per_pattern": self.max_entries_per_pattern, @@ -1186,50 +724,20 @@ def to_dict(self) -> dict[str, Any]: @classmethod def from_dict(cls, data: dict[str, Any]) -> "PatternCache": - """Create PatternCache from serialized dictionary. - - Note: RegionPattern objects are not restored (they're runtime objects). - Only pattern signatures and scheme data are loaded. - - Args: - data: Dictionary containing pattern cache data - - Returns: - Reconstructed PatternCache instance - """ + """Create PatternCache from serialized dictionary.""" cache = cls( minimum_distance=data.get("minimum_distance", 4), max_entries_per_pattern=data.get("max_entries_per_pattern", 32), ) for ps_data in data.get("pattern_schemes", []): - # Create PatternSchemes without pattern object (pattern=None) ps = PatternSchemes.from_dict(ps_data, pattern=None) cache.pattern_schemes.append(ps) return cache def save(self, output_path: str) -> None: - """Save pattern cache to a YAML file. - - Serializes all pattern schemes and their insertion points to a YAML file - that can be loaded later for seeded autotuning. The format matches the - autotuner state file format for consistency. - - **Contents:** - - minimum_distance: Minimum distance between schemes - - max_entries_per_pattern: Maximum number of schemes per pattern - - pattern_schemes: List of all PatternSchemes with their insertion points - - Args: - output_path: File path where the YAML pattern cache file will be written - - Example: - >>> cache = PatternCache(minimum_distance=1, max_entries_per_pattern=16) - >>> for schemes in autotuner.pattern_schemes.values(): - ... cache.add_pattern_schemes(schemes) - >>> cache.save("pattern_cache.yaml") - """ + """Save pattern cache to a YAML file.""" state = self.to_dict() with open(output_path, "w") as f: @@ -1246,29 +754,7 @@ def save(self, output_path: str) -> None: @classmethod def load(cls, input_path: str) -> "PatternCache": - """Load pattern cache from a YAML file. - - Reads a previously saved pattern cache file and reconstructs all pattern - schemes. The loaded pattern cache can be used to seed autotuning with - known-good insertion schemes. - - **Note:** RegionPattern objects are not restored since they depend on - the actual model structure. Only pattern signatures and scheme data - are loaded. - - Args: - input_path: File path to the YAML pattern cache file to load - - Returns: - PatternCache instance with all pattern schemes loaded - - Raises: - FileNotFoundError: If the input_path doesn't exist - - Example: - >>> cache = PatternCache.load("pattern_cache.yaml") - >>> autotuner.initialize(config, pattern_cache=cache) - """ + """Load pattern cache from a YAML file.""" with open(input_path) as f: state = yaml.safe_load(f) @@ -1297,55 +783,7 @@ def __str__(self) -> str: @dataclass class Config: - """Configuration parameters for QDQ autotuning. - - Controls the autotuning process including performance requirements, quantization - parameters, region building, scheme generation, and finetuning behavior. - - Attributes: - # Logging - verbose: Enable detailed logging of autotuning progress (default: False) - - # Quantization Parameters - default_q_scale: Default scale parameter for Q/DQ nodes. Controls quantization - granularity. 
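# Round-trip sketch for the YAML persistence above; the file name is illustrative.
from modelopt.onnx.quantization.autotune.common import PatternCache

cache = PatternCache(minimum_distance=2, max_entries_per_pattern=16)
cache.save("pattern_cache.yaml")

loaded = PatternCache.load("pattern_cache.yaml")
assert loaded.minimum_distance == 2
assert loaded.max_entries_per_pattern == 16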
Typical range: 0.01-0.1 (default: 0.1) - default_q_zero_point: Default zero-point for Q/DQ nodes. Use 0 for signed int8, - 128 for unsigned uint8 (default: 0) - default_quant_type: Quantization type for Q/DQ nodes. Options: "int8" (default), "fp8" - - # Region Builder Settings - maximum_sequence_region_size: Maximum number of nodes in a sequence region during - top-down refinement. Prevents overly large merged regions (default: 10) - minimum_topdown_search_size: Minimum number of nodes in a region to trigger - top-down search during region building (default: 10) - - # Scheme Generation Settings - top_percent_to_mutate: Top percentage of best schemes to use as mutation seeds - during scheme generation. Range: 0.0-1.0 (default: 0.1 = top 10%) - minimum_schemes_to_mutate: Minimum number of schemes to keep as mutation seeds, - even if top_percent_to_mutate results in fewer (default: 10) - maximum_mutations: Maximum number of mutations to apply to a single scheme - during generation (default: 3) - maximum_generation_attempts: Maximum attempts to generate a unique new scheme - before giving up (default: 100) - - # Pattern Cache Settings - pattern_cache_minimum_distance: Minimum edit distance required between schemes in cache. - When adding schemes, if a scheme is too similar (distance < minimum_distance) - to an existing scheme, only the better-performing one is kept (default: 4) - pattern_cache_max_entries_per_pattern: Maximum number of schemes to keep per pattern - in pattern cache. Only the top N best-performing schemes are kept for each pattern. - Use 0 to keep all schemes (default: 32) - - Example: - >>> config = Config( - ... verbose=True, # Enable detailed logging - ... top_percent_to_mutate=0.2, # Use top 20% schemes as seeds - ... pattern_cache_minimum_distance=2, # Require more diversity in cache - ... ) - >>> autotuner = QDQAutotuner(model) - >>> autotuner.initialize(config) - """ + """Configuration parameters for QDQ autotuning.""" # Logging verbose: bool = False diff --git a/modelopt/onnx/quantization/autotune/region_inspect.py b/modelopt/onnx/quantization/autotune/region_inspect.py index 32b7cc58a..8c0950fe9 100644 --- a/modelopt/onnx/quantization/autotune/region_inspect.py +++ b/modelopt/onnx/quantization/autotune/region_inspect.py @@ -89,9 +89,11 @@ def inspect_region_search( logger.info("Analyzing region structure") all_regions = [] for i, region in enumerate(regions): - for child in region.get_children(): - if not include_all_regions and not has_quantizable_operations(child, graph): - region.remove_child(child) + region.children = [ + c + for c in region.get_children() + if include_all_regions or has_quantizable_operations(c, graph) + ] if not include_all_regions and not has_quantizable_operations(region, graph): logger.debug(f"Filtered out region {i} (no quantizable operations)") continue @@ -139,7 +141,7 @@ def inspect_region_search( bar = "█" * min(count, 50) logger.debug(f" {size:4d} nodes: {bar} ({count} regions)") - return regions + return all_regions def main(): diff --git a/modelopt/onnx/quantization/autotune/region_pattern.py b/modelopt/onnx/quantization/autotune/region_pattern.py index 9abd42fd4..ab87abdd6 100644 --- a/modelopt/onnx/quantization/autotune/region_pattern.py +++ b/modelopt/onnx/quantization/autotune/region_pattern.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -13,28 +13,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Region Pattern Signature Generator. - -Provides structural pattern analysis for regions in ONNX computation graphs. -This module enables: -- Pattern-based region grouping by structural similarity -- Deterministic signature generation for pattern matching -- Resolution of insertion points to actual tensor names -- Support for both node-level and region-level Q/DQ insertion - -Key concepts: -- NodeInputInsertionPoint: Specifies Q/DQ insertion at a node's input -- ChildRegionInputInsertionPoint: Specifies Q/DQ insertion at a child region's input boundary -- RegionOutputInsertionPoint: Specifies Q/DQ insertion at a region output (child or node) -- Pattern matching: Groups regions with identical structure for shared optimization -""" +"""Region pattern signature generator for grouping structurally similar regions.""" import hashlib -import logging -from typing import Union +from typing import Union, overload import onnx_graphsurgeon as gs +from modelopt.onnx.op_types import get_symmetric_ops from modelopt.onnx.quantization.autotune.common import InsertionScheme, Region from modelopt.onnx.quantization.autotune.insertion_points import ( ChildRegionInputInsertionPoint, @@ -43,82 +29,32 @@ ResolvedInsertionPoint, ) -# Module logger -logger = logging.getLogger(__name__) - -# Commutative/symmetric operations where operand order doesn't matter -SYMMETRIC_OPERATIONS = { - "Add", - "Mul", - "And", - "Or", - "Xor", - "Equal", - "Max", - "Min", - "Sum", - "Mean", - "BitwiseAnd", - "BitwiseOr", - "BitwiseXor", -} - class RegionPattern: - """Represents a structural pattern of a region. - - The pattern captures the topology and operation types in a region, - enabling pattern matching and region comparison. Patterns are hashable - and can be used as dictionary keys for efficient grouping and lookup. - - Two RegionPattern objects are considered equal if they have the same - signature string, regardless of their size (which represents instance-specific - node count). - - Attributes: - signature: The unique signature string identifying the pattern - size: Total node count for this pattern instance - """ - - # ========================================================================= - # Initialization - # ========================================================================= + """Represents a structural pattern of a region.""" def __init__(self, signature: str, size: int): - """Initialize a region pattern. 
- - Args: - signature: The signature string representing the pattern structure - size: Total size (node count) of the region - """ + """Initialize a region pattern.""" self.signature = signature self.size = size - # ========================================================================= - # Properties - # ========================================================================= - @property def is_empty(self) -> bool: - """Check if pattern represents an empty region.""" - return self.signature == "EMPTY" or self.size == 0 + """Check if the pattern represents an empty region.""" + return self.size == 0 @property def is_composite(self) -> bool: - """Check if pattern represents a composite region.""" + """Check if the pattern represents a composite region.""" return self.signature.startswith("COMPOSITE(") @property def is_leaf(self) -> bool: - """Check if pattern represents a leaf region (no composite structure).""" + """Check if the pattern represents a leaf region (no composite structure).""" return not self.is_composite and not self.is_empty - # ========================================================================= - # Special Methods (Python Protocol) - # ========================================================================= - def __str__(self) -> str: - """String representation showing just the signature.""" + """String representation of the pattern.""" return self.signature def __repr__(self) -> str: @@ -135,66 +71,32 @@ def __hash__(self) -> int: """Hash based on signature for use as dict key.""" return hash(self.signature) - # ========================================================================= - # Public Query Methods - # ========================================================================= - def get_hash(self) -> str: - """Get a 128-bit cryptographic hash of the pattern signature. - - Uses SHA-256 (truncated to 128 bits) to generate a compact, deterministic - hash for efficient pattern comparison and storage. This hash is more - compact than the full signature for storage and comparison purposes. - - Returns: - Hexadecimal string representation of the hash (32 characters) - - Example: - >>> pattern = RegionPattern.from_region(region, graph) - >>> hash_val = pattern.get_hash() # Returns 32 hex characters - >>> print(f"Pattern hash: {hash_val}") - """ - # SHA-256 truncated to 128 bits = 32 hex characters + """Get a 128-bit cryptographic hash of the pattern signature.""" return hashlib.sha256(self.signature.encode("utf-8")).hexdigest()[:32] def get_short_signature(self, max_length: int = 80) -> str: - """Get a truncated version of the signature for display purposes. - - Args: - max_length: Maximum length of the returned string (default: 80) - - Returns: - Truncated signature with '...' suffix if needed - """ + """Get a truncated version of the signature for display purposes.""" if len(self.signature) <= max_length: return self.signature return self.signature[: max_length - 3] + "..." - # ========================================================================= - # Public Pattern Matching and Construction - # ========================================================================= - @classmethod def from_region(cls, region: Region, graph: gs.Graph) -> "RegionPattern": - """Compute a structural pattern for a region. 
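# A quick sketch of the pattern-identity helpers above; the signature string is
# illustrative, and the constructor is RegionPattern(signature, size) as defined here.
from modelopt.onnx.quantization.autotune.region_pattern import RegionPattern

pattern = RegionPattern("Conv[kernel_shape=3x3]->Relu", size=2)
assert pattern.is_leaf and not pattern.is_composite and not pattern.is_empty
assert len(pattern.get_hash()) == 32  # SHA-256 truncated to 128 bits
assert pattern.get_short_signature(max_length=16).endswith("...")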
- - The pattern captures: - - Direct node operations in the region - - Structure of sub-regions (recursively) - - Handles symmetric operations consistently - - Sorts sub-regions by size for determinism - - Args: - region: The region to compute pattern for - graph: The ONNX graph containing the nodes - - Returns: - RegionPattern object containing the signature and metadata - """ + """Compute a structural pattern for a region.""" signature_str = cls._compute_signature_recursive(region, graph) - total_size = region.get_total_size() + total_size = len(region.get_region_nodes_and_descendants()) return cls(signature_str, total_size) + @overload + def matches(self, other: "RegionPattern") -> bool: ... + @overload + def matches(self, other: Region, graph: gs.Graph, scheme: None = None) -> list[int] | None: ... + @overload + def matches( + self, other: Region, graph: gs.Graph, scheme: InsertionScheme + ) -> set[ResolvedInsertionPoint]: ... + def matches( self, other: Union["RegionPattern", Region], @@ -203,21 +105,6 @@ def matches( ) -> bool | list[int] | set[ResolvedInsertionPoint] | None: """Check if this pattern matches another pattern or region. - This method provides three distinct behaviors depending on the arguments: - - 1. **Pattern-to-pattern comparison** (other is RegionPattern, scheme is None): - Returns bool indicating structural equivalence. - - 2. **Pattern-to-region matching** (other is Region, scheme is None): - Returns list of node IDs in pattern order if match succeeds, None otherwise. - - 3. **Pattern-to-region with insertion scheme** (other is Region, scheme provided): - Returns set of resolved insertion points where Q/DQ should be inserted, considering: - - NodeInputInsertionPoints from the scheme (node-level Q/DQ) - - ChildRegionInputInsertionPoints from the scheme (child region input Q/DQ) - - RegionOutputInsertionPoints from the scheme (region output Q/DQ) - Returns empty set if pattern doesn't match. 
- Args: other: Either a RegionPattern or Region to compare with graph: Required when other is a Region (for computing its pattern) @@ -226,11 +113,9 @@ def matches( to resolve to tensor names Returns: - - bool: If other is RegionPattern, True if patterns match - - List[int]: If other is Region and scheme is None, list of node IDs - in pattern order (None if no match) - - Set[ResolvedInsertionPoint]: If other is Region and scheme is provided, - set of resolved insertion points for Q/DQ insertion (empty set if no match) + - True if other is RegionPattern and patterns match + - List of node IDs in pattern order if other is Region and scheme is None, None if no match + - Set of resolved insertion points for Q/DQ insertion if other is Region and scheme is provided Raises: ValueError: If other is Region but graph is not provided, or if scheme @@ -238,142 +123,64 @@ def matches( TypeError: If other is neither RegionPattern nor Region """ if isinstance(other, RegionPattern): - # Behavior 1: Pattern-to-pattern comparison if scheme is not None: raise ValueError("scheme parameter can only be used when matching against a Region") return self._matches_pattern(other) elif isinstance(other, Region) and scheme is None: - # Behavior 2: Pattern-to-region matching (returns node IDs) return self._matches_region(other, graph) elif isinstance(other, Region) and scheme is not None: if graph is None: raise ValueError("graph parameter is required") - # Verify the region matches this pattern + region_pattern = RegionPattern.from_region(other, graph) if self != region_pattern: return set() resolved_ips = set() - # Resolve NodeInputInsertionPoints to tensor names for ip in scheme.node_inputs: resolved_ips.update(ip.resolve(other, graph)) - # Resolve ChildRegionInputInsertionPoints to tensor names for ip in scheme.child_region_inputs: resolved_ips.update(ip.resolve(other, graph)) - # Resolve RegionOutputInsertionPoints to tensor names for ip in scheme.region_outputs: resolved_ips.update(ip.resolve(other, graph)) return resolved_ips else: raise TypeError(f"Expected RegionPattern or Region, got {type(other).__name__}") - # ========================================================================= - # Private Pattern Matching Helpers - # ========================================================================= - def _matches_pattern(self, other: "RegionPattern") -> bool: - """Internal function: Match this pattern against another pattern. - - Args: - other: Another RegionPattern to compare with - - Returns: - True if patterns are structurally equivalent, False otherwise - """ + """Internal function: Match this pattern against another pattern.""" return self == other def _matches_region(self, region: Region, graph: gs.Graph | None) -> list[int] | None: - """Internal function: Match this pattern against a region. - - Args: - region: The region to match against - graph: The ONNX graph containing the nodes - - Returns: - List of node IDs in match order if pattern matches, None otherwise. 
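# The three dispatch modes of matches(), sketched; `region`, `graph`, and
# `scheme` are assumed to exist with the types documented above.
from modelopt.onnx.quantization.autotune.region_pattern import RegionPattern

p1 = RegionPattern("Conv->Relu", size=2)
p2 = RegionPattern("Conv->Relu", size=3)

assert p1.matches(p2)  # 1) pattern vs. pattern: equality is signature-based, size ignored
# node_ids = p1.matches(region, graph)          # 2) node IDs in pattern order, or None
# resolved = p1.matches(region, graph, scheme)  # 3) resolved Q/DQ insertion points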
- Match order follows the pattern computation order: - - Direct nodes of the region (sorted) - - Then recursively, nodes from child regions (in child sort order) - - Raises: - ValueError: If graph is not provided - """ + """Internal function: Match this pattern against a region.""" if graph is None: raise ValueError("graph parameter is required when matching against a Region") - # Compute pattern for the region region_pattern = RegionPattern.from_region(region, graph) - # Check if patterns match if self == region_pattern: - # Return node IDs in match order (same as signature computation order) return self._collect_nodes_in_match_order(region) else: return None def get_full_insertion_scheme(self, region: Region, graph: gs.Graph) -> InsertionScheme: - """Get all possible insertion points for a region in a single InsertionScheme. - - This method first verifies that the region matches this pattern (raises if not). - It then collects all three types of insertion points: - 1. Node input insertion points (Q/DQ at node inputs within the region) - 2. Child region input insertion points (Q/DQ at child region input boundaries) - 3. Region output insertion points (Q/DQ at region output boundaries) - - The returned InsertionScheme contains all possible Q/DQ insertion - locations for this region pattern. This can be used as: - - A baseline scheme with all possible insertions - - A starting point for optimization algorithms - - A comprehensive view of all insertion opportunities - - Important: Pattern-relative indices in the returned scheme are based on - sorted child/node ordering. The sorting order (-level, size) MUST match - insertion_points.py for correct resolution. - - Note: The returned scheme has no child region schemes specified, - latency is set to infinity (unmeasured), and error flag is False. - - Args: - region: The region to analyze - graph: The ONNX graph containing the nodes - - Returns: - InsertionScheme containing all possible insertion points for this region - - Raises: - AssertionError: If the region doesn't match this pattern - """ - # Verify that the region matches this pattern + """Get all possible insertion points for a region in a single InsertionScheme.""" region_pattern = RegionPattern.from_region(region, graph) assert self == region_pattern, "Region pattern mismatch" scheme = InsertionScheme() - # Collect all node input insertion points scheme.node_inputs = NodeInputInsertionPoint.collect_from_region(region, graph) - # Collect all child region input insertion points (at child boundaries) scheme.child_region_inputs = ChildRegionInputInsertionPoint.collect_from_region( region, graph ) - # Collect all region output insertion points scheme.region_outputs = RegionOutputInsertionPoint.collect_from_region(region, graph) return scheme def format_tree(self, region: Region, graph: gs.Graph, indent: int = 0) -> str: - """Format this pattern and region as a human-readable tree. - - Useful for debugging and visualization. 
- - Args: - region: The region associated with this pattern - graph: The ONNX graph - indent: Indentation level - - Returns: - Formatted string representation - """ + """Format this pattern and region as a human-readable tree.""" prefix = " " * indent - result = f"{prefix}Region {region.get_id()}: {self.signature} (size={self.size})\n" + result = f"{prefix}Region {region.id}: {self.signature} (size={self.size})\n" for child in region.get_children(): child_pattern = RegionPattern.from_region(child, graph) @@ -381,289 +188,113 @@ def format_tree(self, region: Region, graph: gs.Graph, indent: int = 0) -> str: return result - # ========================================================================= - # Static Utility Methods - # ========================================================================= - @staticmethod def _collect_nodes_in_match_order(region: Region) -> list[int]: - """Collect node IDs in the same order as signature computation. - - This follows the traversal order used by _compute_signature_recursive: - 1. Direct nodes of the region (sorted by node index) - 2. Recursively, nodes from child regions (children sorted by -level, then size) - - The child sorting order MUST match _compute_signature_recursive and - insertion_points.py for correct pattern-relative index alignment. - - Args: - region: The region to collect nodes from - - Returns: - List of node IDs in match order - """ + """Collect node IDs in the same order as signature computation.""" node_ids = [] - # Add direct nodes of this region (sorted) - node_ids.extend(sorted(region.get_nodes())) - - # Get children and sort them the same way as signature computation - # CRITICAL: This sorting must match _compute_signature_recursive and insertion_points.py - # Sort by: 1) level (descending - higher level first), 2) size (ascending) - children = region.get_children() - sorted_children = sorted(children, key=lambda r: (-r.get_level(), r.get_total_size())) + node_ids.extend(region.get_nodes(sort=True)) + sorted_children = region.get_children(sort=True) - # Recursively collect nodes from children in order for child in sorted_children: node_ids.extend(RegionPattern._collect_nodes_in_match_order(child)) return node_ids - # --- Signature Computation --- - @staticmethod def _compute_signature_recursive(region: Region, graph: gs.Graph) -> str: - """Recursively compute structural signature for a region. 
-
-        The signature captures:
-        - Node operations and their key parameters (for LEAF regions)
-        - Hierarchical structure with child patterns (for COMPOSITE regions)
-        - Deterministic ordering (sorted nodes and children)
-        - Normalized handling of symmetric/commutative operations
-
-        Signature formats:
-        - Empty region: "EMPTY"
-        - Leaf region: "Op1->Op2->Op3" or "Op1[params]->Op2[params]"
-        - Composite with nodes: "COMPOSITE(nodes|child1+child2)"
-        - Composite without nodes: "COMPOSITE(child1+child2)"
-
-        Child Sorting:
-        - Children are sorted by (-level, size) for deterministic signatures
-        - This order MUST match insertion_points.py for correct pattern-relative indexing
-        - Higher-level (more abstract) children come first
-        - Within same level, smaller children come first
+        """Recursively compute structural signature for a region."""
+        nodes_list = list(graph.nodes)
+        node_indices_set = set(region.get_nodes())
 
-        Args:
-            region: The region to process
-            graph: The ONNX graph containing the nodes
+        node_ops = [
+            RegionPattern._make_node_with_params_signature(nodes_list[idx], graph, node_indices_set)
+            for idx in sorted(node_indices_set)
+            if idx < len(nodes_list)
+        ]
 
-        Returns:
-            Deterministic signature string representing the region structure
-        """
-        # Collect direct node operations in this region
-        node_ops = []
-        nodes_list = list(graph.nodes)
-        node_indices_set = region.get_nodes()
-
-        for node_idx in sorted(node_indices_set):
-            if node_idx < len(nodes_list):
-                node = nodes_list[node_idx]
-                # Include operation type and key parameters
-                # Pass region node indices for symmetric operation handling
-                node_sig = RegionPattern._make_node_with_params_signature(
-                    node, graph, node_indices_set
-                )
-                node_ops.append(node_sig)
-
-        # Get child regions
-        children = region.get_children()
-
-        if not children and not node_ops:
-            # Empty region (edge case)
-            return "EMPTY"
+        sorted_children = region.get_children(sort=True)
 
-        if not children:
-            # LEAF region - only direct nodes, no hierarchical structure
-            return RegionPattern._make_node_signature(node_ops)
+        if not sorted_children and not node_ops:
+            return "EMPTY"
 
-        # COMPOSITE region - has hierarchical structure with children
-        # Sort children deterministically for consistent signatures
-        # CRITICAL: This sorting must match insertion_points.py for pattern-relative index alignment
-        # Sort by: 1) level (descending - higher level first), 2) size (ascending)
-        sorted_children = sorted(children, key=lambda r: (-r.get_level(), r.get_total_size()))
+        if not sorted_children:
+            return "->".join(node_ops)
 
-        # Recursively compute child signatures
-        child_signatures = []
-        for child in sorted_children:
-            child_sig = RegionPattern._compute_signature_recursive(child, graph)
-            child_signatures.append(child_sig)
+        child_sigs = "+".join(
+            [RegionPattern._compute_signature_recursive(child, graph) for child in sorted_children]
+        )
 
-        # Combine node operations and child signatures
         if node_ops:
-            # Has both direct nodes and hierarchical children
-            node_sig = RegionPattern._make_node_signature(node_ops)
-            return f"COMPOSITE({node_sig}|{RegionPattern._join_signatures(child_signatures)})"
-        else:
-            # Only children, no direct nodes in this region
-            return f"COMPOSITE({RegionPattern._join_signatures(child_signatures)})"
+            node_sig = "->".join(node_ops)
+            return f"COMPOSITE({node_sig}|{child_sigs})"
+        return f"COMPOSITE({child_sigs})"
 
     @staticmethod
-    def _make_node_with_params_signature(
+    def _get_symmetric_input_signature(
         node: gs.Node, graph: gs.Graph, 
region_node_indices: set - ) -> str: - """Create signature for a single node including its parameters. - - Includes operation type and key attributes that affect behavior. - For symmetric/commutative operations (Add, Mul, etc.), normalizes - input order to ensure consistent signatures regardless of operand order. - Ensures deterministic ordering by sorting attributes by key name. - - Args: - node: The ONNX node - graph: The ONNX graph containing all nodes - region_node_indices: Set of node indices in the current region - - Returns: - Signature string examples: - - "Relu" - Simple operation without attributes - - "Conv[dilations=1x1,kernel_shape=3x3]" - Operation with attributes - - "Add" - Symmetric op with sorted input sources - - "Mul[axis=1]" - Symmetric op with both - """ - op = node.op - - # Handle symmetric operations - normalize input order - if op in SYMMETRIC_OPERATIONS and len(node.inputs) > 1: - # Get input source information for normalization - input_sources = [] - nodes_list = list(graph.nodes) - - # Build node index lookup for efficient producer finding - node_to_idx = {id(n): idx for idx, n in enumerate(nodes_list)} - - for inp in node.inputs: - if inp is None or not hasattr(inp, "inputs") or not inp.inputs: - # Input from graph input or constant - input_sources.append(("external", "input-or-constant")) - else: - # Input from another node's output - producer_node = inp.inputs[0] if inp.inputs else None - if producer_node and id(producer_node) in node_to_idx: - producer_idx = node_to_idx[id(producer_node)] - # Check if producer is in the same region - if producer_idx in region_node_indices: - # Use relative position: 'internal' + producer op type - input_sources.append(("internal", producer_node.op)) - else: - # Producer outside region - input_sources.append(("external", producer_node.op)) - else: - # Unknown producer - input_sources.append(("external", "unknown")) - - # Sort input sources for deterministic ordering - # This ensures Add(A,B) and Add(B,A) have the same signature - sorted_sources = sorted(input_sources) - - # Create source signature - source_sig = ",".join(f"{src[0]}:{src[1]}" for src in sorted_sources) - - # If node has no attributes, return op with input signature - if not node.attrs: - return f"{op}<{source_sig}>" - - # Otherwise, will add input signature after attributes - has_symmetric_inputs = True - else: - has_symmetric_inputs = False + ) -> str | None: + """Compute normalized input source signature for symmetric operations.""" + if node.op not in get_symmetric_ops() or len(node.inputs) <= 1: + return None - # Handle non-symmetric operations or symmetric ops without multiple inputs - if not node.attrs and not has_symmetric_inputs: - return op + nodes_list = list(graph.nodes) + node_to_idx = {id(n): idx for idx, n in enumerate(nodes_list)} - # Extract and format key attributes (only if node has attributes) - if node.attrs: - # Sort attributes alphabetically for deterministic ordering - attr_parts = [] - for key in sorted(node.attrs.keys()): - value = node.attrs[key] - - # Format different attribute types deterministically - if isinstance(value, (list, tuple)): - # Format lists/tuples compactly - # Use 'x' separator for numeric arrays (common in ONNX) - if len(value) > 0 and all(isinstance(v, (int, float)) for v in value): - # Format each element consistently - if all(isinstance(v, int) for v in value): - value_str = "x".join(str(v) for v in value) - else: - # Mixed int/float - format floats with limited precision - value_str = "x".join( - f"{v:.4g}" if 
isinstance(v, float) else str(v) for v in value - ) - else: - # Non-numeric or mixed types - use comma separator - value_str = ",".join(str(v) for v in value) - elif isinstance(value, float): - # Format floats with limited precision to avoid floating point noise - value_str = f"{value:.4g}" - elif isinstance(value, bool): - # Format booleans as 0/1 for compactness - value_str = "1" if value else "0" - elif isinstance(value, bytes): - # Format bytes as hex string (truncated for long values) - hex_str = value.hex() - value_str = hex_str if len(hex_str) <= 16 else f"{hex_str[:16]}..." + input_sources = [] + for inp in node.inputs: + if inp is None or not hasattr(inp, "inputs") or not inp.inputs: + input_sources.append(("external", "input-or-constant")) + else: + producer_node = inp.inputs[0] if inp.inputs else None + if producer_node and id(producer_node) in node_to_idx: + producer_idx = node_to_idx[id(producer_node)] + location = "internal" if producer_idx in region_node_indices else "external" + input_sources.append((location, producer_node.op)) else: - # Default: convert to string - value_str = str(value) - - attr_parts.append(f"{key}={value_str}") + input_sources.append(("external", "unknown")) - # Build final signature with attributes - attr_sig = f"[{','.join(attr_parts)}]" - - # Add symmetric input signature if applicable - if has_symmetric_inputs: - return f"{op}{attr_sig}<{source_sig}>" - else: - return f"{op}{attr_sig}" - else: - # No attributes - already handled above for symmetric ops - return op + sorted_sources = sorted(input_sources) + return ",".join(f"{loc}:{op}" for loc, op in sorted_sources) @staticmethod - def _make_node_signature(ops: list[str]) -> str: - """Create signature from list of node operations. - - Handles single and multiple operations, including symmetric operations. - - Args: - ops: List of operation signatures (may include parameters) - - Returns: - Signature string for the operations - """ - if not ops: - return "" - - if len(ops) == 1: - return ops[0] - - # Multiple operations - create sequential signature - return "->".join(ops) + def _format_attr_value(value: object) -> str: + """Format an attribute value for inclusion in a signature.""" + if isinstance(value, (list, tuple)): + if len(value) > 0 and all(isinstance(v, (int, float)) for v in value): + if all(isinstance(v, int) for v in value): + return "x".join(str(v) for v in value) + return "x".join(f"{v:.4g}" if isinstance(v, float) else str(v) for v in value) + return ",".join(str(v) for v in value) + if isinstance(value, float): + return f"{value:.4g}" + if isinstance(value, bool): + return "1" if value else "0" + if isinstance(value, bytes): + hex_str = value.hex() + return hex_str if len(hex_str) <= 16 else f"{hex_str[:16]}..." + return str(value) @staticmethod - def _join_signatures(signatures: list[str]) -> str: - """Join multiple child signatures. - - Sorts signatures alphabetically to ensure deterministic ordering. - This is critical for pattern matching and comparison. 
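# Spot checks for the attribute formatting rules implemented above.
from modelopt.onnx.quantization.autotune.region_pattern import RegionPattern

assert RegionPattern._format_attr_value([3, 3]) == "3x3"          # int arrays join with 'x'
assert RegionPattern._format_attr_value([0.5, 2]) == "0.5x2"      # mixed numerics keep 'x'
assert RegionPattern._format_attr_value(0.123456) == "0.1235"     # floats at 4 significant digits
assert RegionPattern._format_attr_value(True) == "1"              # bools compact to 0/1
assert RegionPattern._format_attr_value(["SAME", 1]) == "SAME,1"  # non-numeric lists join with ','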
- - Args: - signatures: List of child signatures - - Returns: - Combined signature string with deterministic ordering - """ - if not signatures: - return "" + def _make_node_with_params_signature( + node: gs.Node, graph: gs.Graph, region_node_indices: set + ) -> str: + """Create signature for a single node including its parameters.""" + op = node.op + sym_sig = RegionPattern._get_symmetric_input_signature(node, graph, region_node_indices) - if len(signatures) == 1: - return signatures[0] + attr_sig = "" + if node.attrs: + attr_parts = [ + f"{key}={RegionPattern._format_attr_value(node.attrs[key])}" + for key in sorted(node.attrs.keys()) + ] + attr_sig = f"[{','.join(attr_parts)}]" - # Sort signatures alphabetically for deterministic ordering - # This ensures that parallel/sibling regions always produce - # the same combined signature regardless of traversal order - sorted_sigs = sorted(signatures) - return "+".join(sorted_sigs) + if attr_sig and sym_sig: + return f"{op}{attr_sig}<{sym_sig}>" + if sym_sig: + return f"{op}<{sym_sig}>" + if attr_sig: + return f"{op}{attr_sig}" + return op diff --git a/modelopt/onnx/quantization/autotune/region_search.py b/modelopt/onnx/quantization/autotune/region_search.py index 62906fd50..227e9b9ed 100644 --- a/modelopt/onnx/quantization/autotune/region_search.py +++ b/modelopt/onnx/quantization/autotune/region_search.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -13,218 +13,65 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Region Search - Hierarchical Region Discovery and Partitioning. - -This module provides sophisticated algorithms for discovering and organizing regions -in ONNX computation graphs. It creates hierarchical region structures that respect -computational patterns like divergence, convergence, and sequential operations. - -**Core Functionality:** -- **Two-Phase Region Discovery**: Combines bottom-up partitioning with top-down refinement -- **Pattern Recognition**: Identifies divergence/convergence patterns in computation flow -- **Hierarchical Structure**: Creates COMPOSITE regions containing LEAF child regions -- **Boundary Computation**: Automatically determines region input/output tensors -- **Graph Analysis**: Pre-computes reachability and data flow information - -**Key Algorithms:** - -1. **Bottom-Up Partitioning (RegionPartitioner)**: - - Traverses graph from inputs to outputs - - Identifies divergent nodes where computation branches - - Finds convergence points where branches rejoin - - Creates initial LEAF regions based on these patterns - -2. **Top-Down Refinement (TopDownRegionBuilder)**: - - Merges converged sub-patterns within regions - - Splits long sequences into optimal-sized regions - - Creates hierarchical COMPOSITE region structures - - Respects operation boundaries (Conv, Gemm, etc.) - -3. 
**Combined Strategy (CombinedRegionSearch)**: - - Orchestrates both phases for comprehensive region discovery - - Produces well-formed hierarchical regions covering entire graph - -**Region Types:** -- **LEAF regions**: Contain actual graph nodes (basic building blocks) -- **COMPOSITE regions**: Contain child regions (hierarchical organization) -- **ROOT region**: Single region containing all graph nodes (for analysis) - -**Use Cases:** -- Graph partitioning for distributed execution -- Identifying optimization boundaries for quantization/pruning -- Creating hierarchical abstractions of computation -- Analyzing graph structure and computational patterns - -**Key Classes:** -- **RegionSearchBase**: Base class with common graph analysis utilities -- **CombinedRegionSearch**: Main two-phase region discovery algorithm -- **RegionPartitioner**: Bottom-up partitioning based on divergence/convergence -- **TopDownRegionBuilder**: Top-down refinement creating hierarchical structure -""" - -import argparse -import logging +"""Hierarchical region discovery and partitioning for ONNX graphs.""" + import sys -from collections import Counter, deque +from collections import deque -import onnx import onnx_graphsurgeon as gs +from modelopt.onnx.logging_config import logger from modelopt.onnx.quantization.autotune.common import Region, RegionType -from modelopt.onnx.quantization.autotune.insertion_points import has_quantizable_operations from modelopt.onnx.quantization.autotune.region_pattern import RegionPattern from modelopt.onnx.quantization.graph_utils import get_tensor_consumer_node_indices -# Module logger -logger = logging.getLogger(__name__) - - -def enable_debug(): - """Enable debug-level logging for the region search module.""" - global logger - logger.setLevel(logging.DEBUG) - - DEFAULT_MAX_STEPS = 10 DEFAULT_MAX_NODES_TO_SHOW = 20 +MAX_PROBE_STEPS_AFTER_CONVERGE = 3 class RegionSearchBase: - """Base class for region search algorithms providing common graph analysis utilities. - - This class serves as a foundation for region-based graph analysis algorithms by - providing essential data structures and methods for: - - Graph traversal and reachability analysis - - Divergence/convergence pattern detection - - Region boundary computation - - Tensor flow tracking - - **Core Data Structures:** - - **tensor_users_map**: Maps tensor names to node indices that consume them. - Used to efficiently find divergence points and track data flow. - - **forward_reachable_nodes_map**: Pre-computed forward reachability for all nodes. - Maps each node to all nodes reachable from it (with distances). - - **root**: Root region containing all graph nodes, used as search space. - - **Key Algorithms:** - - **Divergence Detection**: Identifies nodes whose outputs branch to multiple consumers - - **Convergence Detection**: Finds nodes where multiple branches rejoin - - **Boundary Computation**: Determines input/output tensors for regions - - **Reachability Analysis**: Computes forward-reachable nodes with distances - - **Design Pattern:** - This is a base class meant to be subclassed. Subclasses implement specific - region formation strategies (e.g., bottom-up partitioning, top-down refinement) - while reusing the common analysis utilities provided here. 
-
-    **Performance:**
-    Pre-computation in __init__ scales with graph size:
-    - tensor_users_map: O(E) where E = number of edges
-    - forward_reachable_nodes_map: O(N * (N + E)) where N = number of nodes
-
-    For large graphs, initialization may take significant time but enables
-    efficient queries during region formation.
-
-    Attributes:
-        graph: The ONNX computation graph (onnx_graphsurgeon.Graph)
-        root: Root region containing all nodes in the graph
-        tensor_users_map: Mapping from tensor names to consuming node indices
-        forward_reachable_nodes_map: Pre-computed forward reachability for all nodes
-
-    Example:
-        >>> # Typically used as a base class
-        >>> class MyRegionSearch(RegionSearchBase):
-        ...     def find_regions(self):
-        ...         # Use inherited utilities like _is_node_divergent()
-        ...         pass
-    """
+    """Base class for region search algorithms providing common graph analysis utilities."""
 
     def __init__(
-        self, graph: gs.Graph, root: Region | None = None, max_steps: int = DEFAULT_MAX_STEPS
+        self,
+        graph: gs.Graph,
+        root: Region | None = None,
+        max_steps: int = DEFAULT_MAX_STEPS,
+        tensor_users_map: dict[str, list[int]] | None = None,
+        forward_reachable_nodes_map: dict[int, dict[int, int]] | None = None,
     ):
-        """Initialize the base region search with graph analysis.
-
-        Performs pre-computation of essential data structures for efficient
-        region analysis:
-        1. Creates or validates root region containing all nodes
-        2. Builds tensor-to-users mapping for divergence detection
-        3. Pre-computes forward reachability for convergence detection
-
-        Args:
-            graph: The ONNX graph to analyze (onnx_graphsurgeon.Graph)
-            root: Optional root region. If None, creates one containing all nodes.
-            max_steps: Maximum distance for forward reachability pre-computation.
-                Limits memory usage and computation time for large graphs.
-
-        Note:
-            Initialization time scales with graph complexity. For graphs with
-            thousands of nodes, this may take several seconds.
-        """
+        """Initialize the base region search with graph analysis."""
         self.graph = graph
         if root is None:
             root = self._build_root_region()
         self.root = root
-        self.tensor_users_map = get_tensor_consumer_node_indices(self.graph)
-        self.forward_reachable_nodes_map = self._build_forward_reachable_nodes_map(
-            max_steps=max_steps
-        )
+        if tensor_users_map is None:
+            tensor_users_map = get_tensor_consumer_node_indices(self.graph)
+        self.tensor_users_map = tensor_users_map
+        if forward_reachable_nodes_map is None:
+            forward_reachable_nodes_map = self._build_forward_reachable_nodes_map(
+                max_steps=max_steps
+            )
+        self.forward_reachable_nodes_map = forward_reachable_nodes_map
 
     def _build_root_region(self) -> Region:
-        """Create a root region containing all nodes in the graph.
-
-        The root region serves as the universal search space for region
-        formation algorithms. It represents the entire computation graph
-        as a single region before any partitioning.
-
-        Returns:
-            Region of type ROOT containing all graph nodes.
-        """
+        """Create a root region containing all nodes in the graph."""
         root = Region(region_id=0, level=0, region_type=RegionType.ROOT)
         for node_idx in range(len(self.graph.nodes)):
             root.add_node(node_idx)
-        for tensor_name in root.get_inputs():
-            root.add_input(tensor_name)
-        for tensor_name in root.get_outputs():
-            root.add_output(tensor_name)
+        for tensor in self.graph.inputs:
+            root.add_input(tensor.name)
+        for tensor in self.graph.outputs:
+            root.add_output(tensor.name)
         return root
 
     def _is_tensor_divergent(self, tensor_name: str) -> bool:
-        """Check if a tensor is consumed by multiple nodes (divergent).
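# Sketch: the optional constructor arguments above let several searches over the
# same graph reuse the expensive precomputed maps. `graph` is assumed to be an
# onnx_graphsurgeon.Graph loaded elsewhere.
from modelopt.onnx.quantization.autotune.region_search import RegionSearchBase

first = RegionSearchBase(graph)
second = RegionSearchBase(
    graph,
    root=first.root,
    tensor_users_map=first.tensor_users_map,
    forward_reachable_nodes_map=first.forward_reachable_nodes_map,
)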
- - A divergent tensor indicates branching in the computation graph, - where one operation's output feeds into multiple downstream operations. - - Args: - tensor_name: Name of the tensor to check - - Returns: - True if tensor has more than one consumer, False otherwise - """ + """Check if a tensor is consumed by multiple nodes (divergent).""" return len(self.tensor_users_map.get(tensor_name, [])) > 1 def _is_node_divergent(self, node_idx: int) -> bool: - """Check if a node has outputs that branch to multiple consumers. - - A divergent node is one that produces outputs consumed by multiple - downstream nodes, creating branches in the computation graph. These - nodes are important boundaries for region formation. - - **Significance:** - - Divergent nodes often represent natural region boundaries - - They indicate where computation splits into parallel paths - - Useful for identifying opportunities for parallel optimization - - Args: - node_idx: Index of the node to check - - Returns: - True if the node has at least one output consumed by multiple nodes, - False otherwise or if node is not in root region. - - Example: - >>> # Node 10 outputs tensor "X" consumed by nodes 11 and 12 - >>> _is_node_divergent(10) # Returns True - """ + """Check if a node has outputs that branch to multiple consumers.""" if node_idx not in self.root.get_nodes(): logger.debug(f"Node {node_idx} not in root region") return False @@ -245,39 +92,7 @@ def _is_node_divergent(self, node_idx: int) -> bool: def _compute_forward_reachable_nodes( self, start_node_idx: int, max_steps: int ) -> dict[int, int]: - """Compute all nodes reachable forward from a starting node with distances. - - Uses breadth-first search (BFS) to find all nodes reachable by following - forward edges (data flow direction) from the start node, up to a maximum - distance. Records the shortest-path distance to each reachable node. - - **Algorithm:** - 1. Initialize with start node at distance 0 - 2. For each node in queue: - - If at max distance, skip - - For each output tensor: - - For each consumer of that tensor: - - If not yet visited, add to queue with distance+1 - - **Use Cases:** - - Convergence detection: Find where branches rejoin - - Region size estimation: Count nodes in forward cone - - Dependency analysis: Understand downstream impact - - Args: - start_node_idx: Index of node to start search from - max_steps: Maximum forward distance to explore - - Returns: - Dictionary mapping reachable node indices to their distances from start. - Includes start_node_idx mapped to distance 0. - - Example: - >>> # Find all nodes within 5 steps forward of node 10 - >>> reachable = _compute_forward_reachable_nodes(10, 5) - >>> reachable[10] # 0 (start node) - >>> reachable[15] # 3 (if node 15 is 3 steps away) - """ + """Compute all nodes reachable forward from a starting node with distances.""" reachable: dict[int, int] = {start_node_idx: 0} queue: deque[tuple[int, int]] = deque([(start_node_idx, 0)]) while queue: @@ -295,35 +110,7 @@ def _compute_forward_reachable_nodes( return reachable def _build_forward_reachable_nodes_map(self, max_steps: int) -> dict[int, dict[int, int]]: - """Pre-compute forward reachability for all nodes in the graph. - - This is a key optimization that enables efficient convergence detection. - By pre-computing forward reachability once, we can quickly answer queries - like "Can node A reach node B?" and "What is the distance from A to B?" 
- - **Complexity:** - - Time: O(N * (N + E)) where N = nodes, E = edges - - Space: O(N²) in worst case for dense graphs - - **Trade-off:** - Pre-computation takes time upfront but dramatically speeds up convergence - detection, which would otherwise require repeated BFS traversals. - - Args: - max_steps: Maximum forward distance to pre-compute for each node. - Limits both time and space complexity. - - Returns: - Nested dictionary where outer key is start node index, inner key is - reachable node index, and value is shortest-path distance. - - Example: - >>> map = _build_forward_reachable_nodes_map(10) - >>> map[5][8] # Distance from node 5 to node 8 - 3 - >>> 12 in map[5] # Can node 5 reach node 12? - True - """ + """Pre-compute forward reachability for all nodes in the graph.""" logger.debug(f"Building forward reachability map (max_steps={max_steps})...") forward_reachable_nodes_map: dict[int, dict[int, int]] = {} for node_idx in self.root.get_nodes(): @@ -337,43 +124,7 @@ def _build_forward_reachable_nodes_map(self, max_steps: int) -> dict[int, dict[i return forward_reachable_nodes_map def _find_converge_nodes(self, node_idx: int) -> tuple[int | None, set[int]]: - """Find convergence point and intermediate nodes for a divergent node. - - Given a divergent node (where computation branches), this method finds: - 1. The convergence node: Where the branches rejoin - 2. All nodes between divergence and convergence - - **Algorithm:** - 1. Identify all branches from the divergent node - 2. Find nodes reachable from all branches (common nodes) - 3. Select nearest common node that forms a valid region - 4. Compute all nodes between divergence and convergence - - **Convergence Criteria:** - A valid convergence node must: - - Be reachable from all branches - - Form a contiguous region (no nodes escape the region) - - Be the nearest such node (minimize region size) - - **Region Validity:** - A region is valid if all nodes within it either stay in the region - or directly reach the convergence point. No node should reach outside - the region before reaching the convergence point. 
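Condensed, the validity criterion above can be read as a predicate over the precomputed reachability map (a sketch; `reach` stands in for `forward_reachable_nodes_map`, and `region_is_valid` is an illustrative name, not part of the patch):

```python
# Sketch of the region-validity check, with reach[i][j] giving the shortest
# forward distance from node i to node j (as in forward_reachable_nodes_map).
def region_is_valid(
    region_nodes: set[int], converge_idx: int, reach: dict[int, dict[int, int]]
) -> bool:
    for r in region_nodes:
        r_reach = reach.get(r, {})
        d_conv = r_reach.get(converge_idx, float("inf"))
        for t, d_t in r_reach.items():
            # Reaching a node outside the region sooner than the convergence
            # point means computation escapes the region.
            if t not in region_nodes and d_t < d_conv:
                return False
    return True
```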
- - Args: - node_idx: Index of the divergent node to find convergence for - - Returns: - Tuple of (converge_node_idx, visited_nodes): - - converge_node_idx: Index of convergence node, or None if not found - - visited_nodes: Set of node indices between divergence and convergence - - Example: - >>> # Node 10 branches to 11 and 12, which rejoin at node 15 - >>> converge_idx, visited = _find_converge_nodes(10) - >>> converge_idx # 15 - >>> visited # {10, 11, 12, 13, 14} (all nodes in between) - """ + """Find convergence point and intermediate nodes for a divergent node.""" node = self.graph.nodes[node_idx] logger.debug(f"Finding convergence for node {node_idx} ({node.op})") @@ -382,13 +133,7 @@ def _find_converge_nodes(self, node_idx: int) -> tuple[int | None, set[int]]: if output.name in self.tensor_users_map: branches.extend(self.tensor_users_map[output.name]) - seen: set[int] = set() - unique_branches: list[int] = [] - for branch_idx in branches: - if branch_idx not in seen: - seen.add(branch_idx) - unique_branches.append(branch_idx) - branches = unique_branches + branches = list(dict.fromkeys(branches)) logger.debug(f" {len(branches)} unique branches found") @@ -397,30 +142,15 @@ def _find_converge_nodes(self, node_idx: int) -> tuple[int | None, set[int]]: logger.debug(" Insufficient branches for convergence") return None, set() - # ===================================================================== # STEP 1: Find Common Reachable Nodes (Potential Convergence Points) - # ===================================================================== - # A valid convergence node must be reachable from ALL branches. - # Use pre-computed forward reachability for efficiency. - - # Collect forward-reachable nodes for each branch - branch_reachable: list[dict[int, int]] = [] - for branch_idx in branches: - reachable = self.forward_reachable_nodes_map.get(branch_idx, {}) - branch_reachable.append(reachable) + branch_reachable = [self.forward_reachable_nodes_map.get(b, {}) for b in branches] if not branch_reachable: logger.debug(" No reachable nodes from branches") return None, set() - # Find intersection: nodes reachable from ALL branches - # These are the only candidates for convergence points - common_nodes = set(branch_reachable[0].keys()) - for reachable in branch_reachable[1:]: - common_nodes.intersection_update(reachable.keys()) - + common_nodes = set.intersection(*[set(r.keys()) for r in branch_reachable]) logger.debug(f" {len(common_nodes)} common nodes found") - # Remove the divergent node itself (not a convergence point) common_nodes.discard(node_idx) @@ -428,64 +158,31 @@ def _find_converge_nodes(self, node_idx: int) -> tuple[int | None, set[int]]: logger.debug(" No valid convergence candidates") return None, set() - # ===================================================================== # STEP 2: Select Best Convergence Node with Region Validity Check - # ===================================================================== - # Not all common nodes make good convergence points. We need to ensure - # the region formed is "valid" - i.e., contiguous with no escaping edges. 
- # - # Region validity criterion: - # For every node R in the region (between divergence and candidate): - # For every node T reachable from R: - # If T is outside the region: - # T must be at least as far from R as the candidate is - # (i.e., R doesn't "escape" before reaching candidate) - converge_node_idx: int | None = None min_max_distance = float("inf") - # Get all nodes reachable from the divergent node reachable_from_start = self.forward_reachable_nodes_map.get(node_idx, {}) # Evaluate each candidate convergence point for candidate_idx in common_nodes: - # --------------------------------------------------------------- # Define the potential region: nodes between start and candidate - # --------------------------------------------------------------- - # Region = nodes reachable from start BUT NOT reachable from candidate - # (candidate acts as the boundary) region_nodes: set[int] = set() region_nodes.update(set(reachable_from_start.keys())) reachable_from_candidate = self.forward_reachable_nodes_map.get(candidate_idx, {}) - # Remove nodes beyond the candidate (not in our region) region_nodes.difference_update(set(reachable_from_candidate.keys())) - # --------------------------------------------------------------- - # Validate region: Check for "escaping" edges - # --------------------------------------------------------------- - # A region is invalid if any node inside can reach a node outside - # BEFORE reaching the convergence point. This would mean the region - # has edges that "leak out" and isn't properly bounded. broken_region = False - - # Check each node in the proposed region for rnode_index in region_nodes: - # Get all nodes reachable from this region node reachable_from_rnode = self.forward_reachable_nodes_map.get(rnode_index, {}) - - # Distance from this node to the candidate (convergence) rnode_to_candidate_distance = reachable_from_rnode.get(candidate_idx, float("inf")) - - # Check all nodes reachable from this region node for test_node_idx in reachable_from_rnode: # Skip nodes that are inside the region (they're fine) if test_node_idx in region_nodes: continue - # test_node is OUTSIDE the region. Check if it's "escaping" # An escaping edge: region_node reaches test_node BEFORE candidate rnode_to_test_distance = reachable_from_rnode.get(test_node_idx, float("inf")) - # If either distance is infinite, region is broken # (indicates disconnected components or unreachable convergence) if rnode_to_test_distance == float( @@ -493,31 +190,22 @@ def _find_converge_nodes(self, node_idx: int) -> tuple[int | None, set[int]]: ) or rnode_to_candidate_distance == float("inf"): broken_region = True break - # If test_node is closer than candidate, we have an escape! # This means computation flows OUT of region before converging if rnode_to_test_distance < rnode_to_candidate_distance: broken_region = True break - if broken_region: break - # Skip this candidate if region is invalid if broken_region: continue - - # --------------------------------------------------------------- # Valid candidate! 
Check if it's the nearest one - # --------------------------------------------------------------- - # We want the closest convergence point to minimize region size - # "Distance" = maximum distance from any branch to convergence max_distance = max(reachable[candidate_idx] for reachable in branch_reachable) if max_distance < min_max_distance: min_max_distance = max_distance converge_node_idx = candidate_idx - # If no valid convergence found, this divergence has no convergence if converge_node_idx is None: logger.debug(" No valid convergence found") @@ -527,64 +215,19 @@ def _find_converge_nodes(self, node_idx: int) -> tuple[int | None, set[int]]: logger.debug( f" Convergence at node {converge_node_idx} ({converge_node.op}), distance {min_max_distance}" ) - - # ===================================================================== # STEP 3: Compute All Nodes Between Divergence and Convergence - # ===================================================================== - # Now that we have a valid convergence point, we need to identify ALL - # nodes that should be included in the convergence region. - # - # A node is "between" divergence and convergence if: - # 1. It's reachable from the divergence node (on some path from divergence) - # 2. The convergence node is reachable from it (on some path to convergence) - # 3. It's not the convergence node itself (convergence is the boundary) - # - # This captures all the "interior" nodes of the funnel/diamond pattern, - # including all branches and intermediate computations. - visited_nodes: set[int] = set() - - # Check each node reachable from the divergent node for candidate_idx in reachable_from_start: - # Skip the convergence node itself (it's the boundary, not interior) if candidate_idx == converge_node_idx: continue - - # Check if this node can reach the convergence node - # If yes, it's on a path from divergence to convergence reachable_from_candidate = self.forward_reachable_nodes_map.get(candidate_idx, {}) if converge_node_idx in reachable_from_candidate: - # This node is between divergence and convergence! visited_nodes.add(candidate_idx) - logger.debug(f" {len(visited_nodes)} nodes between divergence and convergence") return converge_node_idx, visited_nodes def _max_distance_to_nodes(self, src_idx: int, dst_indices: set[int]) -> int: - """Compute maximum distance from a source node to a set of destination nodes. - - Uses pre-computed forward reachability to efficiently find the maximum - shortest-path distance from src_idx to any node in dst_indices. - - **Use Cases:** - - Determine if a convergence region is within acceptable size limits - - Measure the "spread" of nodes in a potential region - - Validate region compactness constraints - - Args: - src_idx: Source node index - dst_indices: Set of destination node indices - - Returns: - Maximum distance from src to any node in dst_indices. - Returns 0 if dst_indices is empty or no nodes are reachable. - - Example: - >>> # Check if all nodes are within 10 steps - >>> max_dist = _max_distance_to_nodes(start_node, candidate_nodes) - >>> if max_dist <= 10: - ... 
# Region is compact enough - """ + """Compute maximum distance from a source node to a set of destination nodes.""" max_distance = 0 for dst_idx in dst_indices: reachable = self.forward_reachable_nodes_map.get(src_idx, {}) @@ -599,172 +242,102 @@ def _max_distance_to_nodes(self, src_idx: int, dst_indices: set[int]) -> int: def compute_region_boundaries(self, region: Region, include_constant: bool = False) -> None: """Compute input and output tensor boundaries for a region. - **Algorithm:** - 1. Collect all tensors consumed by region nodes (potential inputs) - 2. Collect all tensors produced by region nodes (potential outputs) - 3. Input = consumed tensors NOT produced by region nodes - 4. Output = produced tensors consumed by nodes OUTSIDE the region - - This accurately captures the data flow boundaries of the region. - Args: region: The region to compute boundaries for + include_constant: Whether to include constant tensors in inputs """ - node_indices = region.get_all_nodes_recursive() - all_inputs: set[str] = set() - all_outputs: set[str] = set() - internal_tensors: set[str] = set() + node_indices = region.get_region_nodes_and_descendants() + + consumed_tensors: set[str] = set() + produced_tensors: set[str] = set() + region_outputs: set[str] = set() - # First pass: collect all inputs and outputs for node_idx in node_indices: if node_idx >= len(self.graph.nodes): continue node = self.graph.nodes[node_idx] - # Collect input tensors + + # Collect consumed tensors (potential inputs) for input_tensor in node.inputs: if isinstance(input_tensor, gs.Constant) and not include_constant: continue - all_inputs.add(input_tensor.name) - # Collect output tensors - for output_tensor in node.outputs: - all_outputs.add(output_tensor.name) - internal_tensors.add(output_tensor.name) - - # Region inputs = consumed tensors not produced internally - region_inputs = all_inputs - internal_tensors + consumed_tensors.add(input_tensor.name) - # Region outputs = produced tensors consumed externally - region_outputs: set[str] = set() - for node_idx in node_indices: - if node_idx >= len(self.graph.nodes): - continue - node = self.graph.nodes[node_idx] + # Collect produced tensors and determine outputs for output_tensor in node.outputs: tensor_name = output_tensor.name - if tensor_name not in self.tensor_users_map: - region_outputs.add(tensor_name) - continue - # Check if any consumer is outside the region - has_external_consumer = False - # Get consumer nodes from tensor_users_map - consumer_indices = self.tensor_users_map[tensor_name] - for consumer_idx in consumer_indices: - if consumer_idx not in node_indices: - # Consumer is outside the region - has_external_consumer = True - break - if has_external_consumer: - region_outputs.add(tensor_name) - # Also check if this is a graph output - if output_tensor in self.graph.outputs: + produced_tensors.add(tensor_name) + + consumer_indices = self.tensor_users_map.get(tensor_name, []) + has_external_consumer = any(idx not in node_indices for idx in consumer_indices) + is_graph_output = output_tensor in self.graph.outputs + + if has_external_consumer or is_graph_output or not consumer_indices: region_outputs.add(tensor_name) - # Add to region - region.inputs = sorted(region_inputs) + # Region inputs = consumed tensors not produced internally + region.inputs = sorted(consumed_tensors - produced_tensors) region.outputs = sorted(region_outputs) logger.debug( - f"Computed boundaries: {len(region_inputs)} inputs, {len(region_outputs)} outputs" + f"Computed boundaries: 
{len(region.inputs)} inputs, {len(region.outputs)} outputs" ) def print_tree( self, region: Region | None = None, indent: int = 0, - max_nodes_to_show: int = DEFAULT_MAX_NODES_TO_SHOW, + max_items: int = DEFAULT_MAX_NODES_TO_SHOW, file=None, ) -> None: - """Print hierarchical region tree in human-readable text format. - - Recursively prints the region hierarchy with indentation showing depth. - For each region, displays: - - ID, level, and type (LEAF/COMPOSITE/ROOT) - - Node counts (direct and recursive) - - I/O tensor counts - - Sample of nodes in the region (up to max_nodes_to_show) - - Child regions (recursively) - - Args: - region: Region to print (None defaults to root) - indent: Current indentation level (0 = root) - max_nodes_to_show: Maximum nodes to display per region (default: 5) - file: Output file object (None defaults to stdout) - - Example: - >>> builder.print_tree() - ├─ Region 0 (Level 0, Type: ROOT) - │ ├─ Direct nodes: 0 - │ └─ Children: 2 - │ ├─ Region 1 (Level 1, Type: COMPOSITE) - ... - """ + """Print hierarchical region tree in human-readable text format.""" region = region or self.root - if region is None: - return - if file is None: - file = sys.stdout + file = file or sys.stdout + p = " " * indent - prefix = " " * indent + def print_items(items, label, formatter=str): + """Print a truncated list of items.""" + items = list(items) + print(f"{p}│ ├─ {label}: {len(items)}", file=file) + for item in items[:max_items]: + print(f"{p}│ │ - {formatter(item)}", file=file) + if len(items) > max_items: + print(f"{p}│ │ ... and {len(items) - max_items} more", file=file) - # Print region header - region_type = region.get_type().value + # Header print( - f"{prefix}├─ Region {region.get_id()} (Level {region.get_level()}, Type: {region_type})", + f"{p}├─ Region {region.id} (Level {region.level}, Type: {region.type.value})", file=file, ) - # Print region size info + # Counts direct_nodes = region.get_nodes() - total_nodes = region.get_all_nodes_recursive() - num_children = len(region.get_children()) - - print(f"{prefix}│ ├─ Direct nodes: {len(direct_nodes)}", file=file) - print(f"{prefix}│ ├─ Total nodes (recursive): {len(total_nodes)}", file=file) - print(f"{prefix}│ ├─ Children: {num_children}", file=file) - - # Print region I/O info - inputs = region.get_inputs() - outputs = region.get_outputs() - print(f"{prefix}│ ├─ Inputs: {len(inputs)} tensors", file=file) - if inputs: - for tensor_name in list(inputs)[:max_nodes_to_show]: - print(f"{prefix}│ │ - {tensor_name}", file=file) - if len(inputs) > max_nodes_to_show: - print(f"{prefix}│ │ ... and {len(inputs) - max_nodes_to_show} more", file=file) - print(f"{prefix}│ └─ Outputs: {len(outputs)} tensors", file=file) - if outputs: - for tensor_name in list(outputs)[:max_nodes_to_show]: - print(f"{prefix}│ - {tensor_name}", file=file) - if len(outputs) > max_nodes_to_show: - print(f"{prefix}│ ... 
and {len(outputs) - max_nodes_to_show} more", file=file) - - # Print direct nodes in this region (if any) + children = region.get_children() + print(f"{p}│ ├─ Direct nodes: {len(direct_nodes)}", file=file) + print(f"{p}│ ├─ Total nodes: {len(region.get_region_nodes_and_descendants())}", file=file) + print(f"{p}│ ├─ Children: {len(children)}", file=file) + + # I/O + print_items(region.inputs, "Inputs") + print_items(region.outputs, "Outputs") + + # Direct nodes if direct_nodes: - print(f"{prefix}│", file=file) - print(f"{prefix}│ Nodes in this region:", file=file) - nodes_list = sorted(direct_nodes)[:max_nodes_to_show] - for node_idx in nodes_list: + print(f"{p}│\n{p}│ Nodes in this region:", file=file) + for node_idx in sorted(direct_nodes)[:max_items]: if node_idx < len(self.graph.nodes): node = self.graph.nodes[node_idx] - print( - f"{prefix}│ - Node {node_idx}: {node.op} (name: {node.name})", file=file - ) - - if len(direct_nodes) > max_nodes_to_show: - print( - f"{prefix}│ ... and {len(direct_nodes) - max_nodes_to_show} more nodes", - file=file, - ) + print(f"{p}│ - Node {node_idx}: {node.op} ({node.name})", file=file) + if len(direct_nodes) > max_items: + print(f"{p}│ ... and {len(direct_nodes) - max_items} more", file=file) - # Print children (recursively) - children = region.get_children() + # Children if children: - print(f"{prefix}│", file=file) - print(f"{prefix}│ Child regions:", file=file) - for child_index, child in enumerate(children): - print(f"{prefix}│", file=file) - self.print_tree(child, indent + 1, max_nodes_to_show, file) + print(f"{p}│\n{p}│ Child regions:", file=file) + for child in children: + print(f"{p}│", file=file) + self.print_tree(child, indent + 1, max_items, file) class RegionPartitioner(RegionSearchBase): @@ -773,125 +346,28 @@ class RegionPartitioner(RegionSearchBase): This class implements Phase 1 of the combined region search strategy. It performs a systematic traversal of the computation graph from inputs to outputs, identifying natural boundaries for region formation based on computation flow patterns. - - **Core Strategy:** - Partitions the graph by analyzing three types of computational patterns: - - 1. **Divergent Nodes with Convergence:** - - Nodes whose outputs branch to multiple paths (divergence) - - Paths that eventually rejoin at a common node (convergence) - - Creates a single region encompassing divergence + branches + convergence - - Example: A → (B,C) → D creates region containing {A, B, C, D} - - 2. **Divergent Nodes without Convergence:** - - Nodes whose outputs branch but never rejoin - - Creates a single-node "orphan" region for the divergent node - - Example: A → (B,C) with no convergence creates region {A} - - 3. 
**Linear Sequences:** - - Chains of non-divergent nodes (simple sequential computation) - - Groups entire sequence into one region - - Example: A → B → C → D creates region {A, B, C, D} - - **Algorithm Overview:** - ``` - For each node in graph order: - If already visited: skip - If divergent: - Find convergence point - If convergence exists within threshold: - Create region with all nodes between divergence and convergence - Else: - Create single-node region (orphan) - Else (non-divergent): - Build sequence: follow chain until hitting divergent node - Create region containing entire sequence - ``` - - **Key Features:** - - **Complete Coverage:** Every node is assigned to exactly one region - - **Convergence Detection:** Uses pre-computed reachability for efficiency - - **Distance Threshold:** Limits region size to DEFAULT_MAX_STEPS - - **Sequential Processing:** Respects data flow order for natural groupings - - **Region Types Created:** - All regions created by this class are LEAF regions (level 0). Higher-level - structure is created later by TopDownRegionBuilder. - - **State Management:** - - **visited_nodes:** Tracks which nodes have been assigned to regions - - **current_region:** Region being built (commit when complete) - - **regions:** List of completed regions - - **current_region_id:** Counter for unique region IDs - - **Output:** - A list of LEAF regions that partition the entire graph. These regions - serve as input to Phase 2 (TopDownRegionBuilder) for refinement. - - **Example:** - ```python - partitioner = RegionPartitioner(graph) - initial_regions = partitioner.partition_graph() - - # Analyze results - print(f"Created {len(initial_regions)} regions") - print(f"Covered {len(partitioner.visited_nodes)} / {len(graph.nodes)} nodes") - - # Typical output for a ResNet layer: - # - Conv node → orphan region (diverges to BN and skip path) - # - BN → ReLU sequence → sequential region - # - Add (convergence) → orphan or part of next sequence - ``` - - **Performance:** - - Time: O(N) where N = number of nodes (each visited once) - - Space: O(N) for visited_nodes set and region storage - - Attributes: - regions: List of completed LEAF regions - current_region: Region currently being built (None if between regions) - current_region_id: Counter for assigning unique region IDs - visited_nodes: Set of node indices already assigned to regions - - See Also: - TopDownRegionBuilder: Phase 2 refinement of partitioner output - CombinedRegionSearch: Orchestrates both phases """ - def __init__(self, graph: gs.Graph): - """Initialize the partitioner with a computation graph. - - Sets up necessary data structures and inherits graph analysis utilities - from RegionSearchBase (tensor users map, reachability, etc.). - - Args: - graph: The ONNX graph to partition (onnx_graphsurgeon.Graph) - """ - super().__init__(graph, root=None) + def __init__( + self, + graph: gs.Graph, + tensor_users_map: dict[str, list[int]] | None = None, + forward_reachable_nodes_map: dict[int, dict[int, int]] | None = None, + ): + """Initialize the partitioner with a computation graph.""" + super().__init__( + graph, + root=None, + tensor_users_map=tensor_users_map, + forward_reachable_nodes_map=forward_reachable_nodes_map, + ) self.regions: list[Region] = [] self.current_region: Region | None = None self.current_region_id: int = 0 self.visited_nodes: set[int] = set() def _append_node_to_region(self, node_idx: int): - """Add a node to the current region, creating a new region if needed. 
- - This is the primary method for building regions incrementally. If no - region is currently active, creates a new LEAF region. Then adds the - specified node to that region. - - **Usage Pattern:** - Typically called multiple times to build up a region, then followed - by _commit_region() to finalize and store the completed region. - - Args: - node_idx: Index of node to add to current region - - Side Effects: - - Creates new region if current_region is None - - Increments current_region_id when creating new region - - Adds node to current_region - """ + """Add a node to the current region, creating a new region if needed.""" node = self.graph.nodes[node_idx] if self.current_region is None: self.current_region = Region( @@ -902,37 +378,15 @@ def _append_node_to_region(self, node_idx: int): self.current_region.add_node(node_idx) logger.debug( - f" Added node {node_idx} ({node.op}), region size: {self.current_region.get_size()}" + f" Added node {node_idx} ({node.op}), region size: {len(self.current_region.nodes)}" ) def _commit_region(self): - """Finalize and store the current region being built. - - Completes region construction by: - 1. Computing input/output tensor boundaries - 2. Adding region to the completed regions list - 3. Resetting current_region to None for next region - - **Boundary Computation:** - Determines which tensors flow into and out of the region based on - which nodes produce/consume them. This is essential for understanding - region dependencies. - - **Post-Conditions:** - - current_region is added to regions list - - current_region is reset to None - - Region has computed input/output tensor lists - - Side Effects: - - Appends current_region to self.regions - - Sets current_region to None - - Logs region commit with size info - """ + """Finalize and store the current region being built.""" if self.current_region is not None: - region_size = self.current_region.get_size() + region_size = len(self.current_region.nodes) region_id = self.current_region.id - # Compute input/output tensor boundaries self.compute_region_boundaries(self.current_region) self.regions.append(self.current_region) @@ -944,75 +398,29 @@ def _commit_region(self): logger.debug("No region to commit") def _build_sequence_from_node(self, node_idx: int, max_nodes: int = -1): - """Build a region from a linear sequence of non-divergent nodes. - - Starting from a non-divergent node, follows the forward chain of nodes, - adding each non-divergent node to the current region. Stops when hitting: - - A divergent node (branches to multiple paths) - - A node already visited - - End of graph - - **Algorithm:** - ``` - queue = [start_node] - while queue not empty: - node = dequeue() - if node is divergent: - stop (this node will be handled separately) - else: - add node to region - add all successors to queue - commit region - ``` - - **Example:** - For graph: Conv → BN → ReLU → MaxPool (no branching) - Creates one region containing all four nodes. - - **Stopping Conditions:** - - Divergent node encountered (boundary for this region) - - All successors already visited - - No more forward connections - - Args: - node_idx: Index of starting node (must be non-divergent) - - Side Effects: - - Adds nodes to current_region via _append_node_to_region - - Marks nodes as visited - - Commits completed region - - Note: - Always commits the region at the end, even if only one node was added. 
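To illustrate the rewritten traversal below with hypothetical indices:

```python
# Chain A -> B -> C, where C's output fans out to two consumers:
# _build_sequence_from_node(A) appends A and B, then appends C, marks all
# three visited, and stops without queueing C's successors (C is divergent).
```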
-        """
-        start_node = self.graph.nodes[node_idx]
-        logger.debug(f"Building sequence from node {node_idx} ({start_node.op})")
+        """Build a region from a linear sequence of nodes.
+
+        The walk follows forward edges and stops after appending the first
+        divergent node it reaches (or after ``max_nodes`` additions).
+        """
+        logger.debug(f"Building sequence from node {node_idx} ({self.graph.nodes[node_idx].op})")

         queue: deque[int] = deque([node_idx])
+        enqueued: set[int] = {node_idx}
         nodes_added = 0

-        while len(queue) > 0:
-            current_node_idx = queue.popleft()
-            current_node = self.graph.nodes[current_node_idx]
+        while queue:
+            current_idx = queue.popleft()
+            node = self.graph.nodes[current_idx]

-            if not self._is_node_divergent(current_node_idx):
-                self._append_node_to_region(current_node_idx)
-                self.visited_nodes.add(current_node_idx)
-                nodes_added += 1
-
-                # Find successors
-                successor_count = 0
-                for output_tensor in current_node.outputs:
-                    if output_tensor.name in self.tensor_users_map:
-                        successors = self.tensor_users_map[output_tensor.name]
-                        successor_count += len(successors)
-                        queue.extend(successors)
+            self._append_node_to_region(current_idx)
+            self.visited_nodes.add(current_idx)
+            nodes_added += 1
+
+            if self._is_node_divergent(current_idx):
+                logger.debug(f"  Stopped at divergent node {current_idx} ({node.op})")
             else:
-                self._append_node_to_region(current_node_idx)
-                nodes_added += 1
-                logger.debug(f"  Stopped at divergent node {current_node_idx} ({current_node.op})")
+                # Queue successors for non-divergent nodes, skipping nodes that
+                # are already queued or were assigned to an earlier region
+                for output in node.outputs:
+                    for succ_idx in self.tensor_users_map.get(output.name, []):
+                        if succ_idx in enqueued or succ_idx in self.visited_nodes:
+                            continue
+                        enqueued.add(succ_idx)
+                        queue.append(succ_idx)

-            if max_nodes > 0 and nodes_added >= max_nodes:
+            if 0 < max_nodes <= nodes_added:
                 logger.debug("  Max nodes reached")
                 break

@@ -1039,22 +447,6 @@ def _build_small_converged_region(
          \\ /
       convergence
       ```
-
-        **Example:**
-        For ResNet skip connection:
-        - start_node: Output of previous layer (branches)
-        - visited_nodes: {Conv, BN, ReLU, Conv, BN} (main path)
-        - converge_node: Add operation (merges with skip)
-
-        Args:
-            start_node_idx: The divergent node where branches begin
-            converge_node_idx: Where branches rejoin (currently unused but kept for API)
-            visited_nodes: All nodes between divergence and convergence
-
-        Side Effects:
-            - Adds all nodes to current region
-            - Marks all nodes as visited
-            - Commits the completed region
         """
         visited_nodes.remove(start_node_idx)
         for node_idx in sorted(visited_nodes):
@@ -1063,59 +455,10 @@
             if not self._is_node_divergent(converge_node_idx):
                 self._append_node_to_region(converge_node_idx)
                 self.visited_nodes.add(converge_node_idx)
-            self._build_sequence_from_node(converge_node_idx, max_nodes=3)
+            self._build_sequence_from_node(converge_node_idx, max_nodes=MAX_PROBE_STEPS_AFTER_CONVERGE)

     def _build_region_from_node(self, node_idx: int):
-        """Process a single node and create appropriate region(s) based on its pattern.
-
-        This is the core dispatch method that determines how to handle each node
-        based on whether it's divergent (branches) or sequential. Implements the
-        three pattern recognition strategies described in the class documentation.
- - **Decision Logic:** - ``` - If node already visited: - Skip (already in a region) - Else if node is divergent: - Try to find convergence point - If convergence found within distance threshold: - Create convergence region (divergence + branches + convergence) - Else: - Create orphan region (just the divergent node) - Else (non-divergent): - Build sequence region (follow chain until divergence) - ``` - - **Pattern 1: Divergent with Convergence (Ideal Case)** - Creates a complete "funnel" region capturing parallel branches: - - Example: ResNet skip connection (Conv branch + identity → Add) - - Condition: converge_node found AND distance < DEFAULT_MAX_STEPS - - Result: One region containing all nodes between divergence and convergence - - **Pattern 2: Divergent without Convergence (Boundary Case)** - Creates a single-node "orphan" region: - - Example: Final layer that branches to multiple outputs - - Condition: No convergence found OR convergence too far away - - Result: Region containing only the divergent node - - **Pattern 3: Sequential Chain (Common Case)** - Creates a region containing linear sequence: - - Example: Conv → BN → ReLU → MaxPool - - Condition: Node is not divergent - - Result: Region containing the full non-divergent chain - - Args: - node_idx: Index of node to process - - Side Effects: - - Marks processed nodes as visited - - Creates and commits region(s) via helper methods - - May recursively process successor nodes (in sequence building) - - Note: - This method is idempotent - calling it multiple times on the same - node has no effect after the first call (due to visited check). - """ + """Process a single node and create appropriate region(s) based on its pattern.""" node = self.graph.nodes[node_idx] # Skip nodes already assigned to regions @@ -1128,13 +471,10 @@ def _build_region_from_node(self, node_idx: int): # Pattern 1 & 2: Handle divergent nodes if self._is_node_divergent(node_idx): logger.debug(" Divergent node, searching for convergence") - # Attempt to find where branches rejoin converge_node_idx, visited_nodes = self._find_converge_nodes(node_idx) - # Check if convergence creates a reasonable-sized region max_distance = self._max_distance_to_nodes(node_idx, visited_nodes) - # Pattern 1: Convergence found and region size is acceptable if converge_node_idx is not None and max_distance < DEFAULT_MAX_STEPS: converge_node = self.graph.nodes[converge_node_idx] @@ -1153,84 +493,23 @@ def _build_region_from_node(self, node_idx: int): self._append_node_to_region(node_idx) self.visited_nodes.add(node_idx) self._commit_region() - # Pattern 3: Handle non-divergent (sequential) nodes else: + # Pattern 3: Handle non-divergent (sequential) nodes logger.debug(" Non-divergent node, building sequence") # Build region by following the linear chain forward self._build_sequence_from_node(node_idx) self._commit_region() def partition_graph(self): - """Partition the entire graph into non-overlapping LEAF regions. - - This is the main entry point for bottom-up graph partitioning. Performs - a single pass over all nodes in graph order, creating regions based on - divergence/convergence patterns and sequential chains. 
- - **Algorithm:** - ``` - For each node in graph (in index order): - If node not yet visited: - Analyze node type (divergent vs sequential) - Create appropriate region(s) for node and its neighborhood - Mark processed nodes as visited - - Result: Complete partitioning where every node belongs to exactly one region - ``` - - **Processing Order:** - Nodes are processed in index order (typically matches graph construction - order / topological-ish order). This tends to group naturally related - operations together. - - **Completeness Guarantee:** - Every node in the graph will be assigned to exactly one region. The - visited_nodes set ensures no node is processed twice, and the loop over - all indices ensures no node is skipped. - - **Region Types Created:** - - Convergence regions: Divergent node + branches + convergence - - Orphan regions: Single divergent node with no close convergence - - Sequence regions: Linear chains of non-divergent nodes - - **Output Quality:** - - Total regions: Typically 10-30% of total nodes (varies by graph) - - Region sizes: Mix of small (1-3 nodes) and medium (5-15 nodes) - - Coverage: 100% of graph nodes - - Returns: - List of LEAF regions that partition the entire graph. - Each node appears in exactly one region. - Regions are stored in self.regions and also returned. - - Side Effects: - - Populates self.regions with created regions - - Populates self.visited_nodes with all node indices - - Logs progress and statistics - - Example: - >>> partitioner = RegionPartitioner(graph) - >>> regions = partitioner.partition_graph() - >>> # Verify complete coverage - >>> all_nodes = set() - >>> for region in regions: - ... all_nodes.update(region.get_nodes()) - >>> assert all_nodes == set(range(len(graph.nodes))) - - Performance: - - Time: O(N) where N = number of nodes (each visited once) - - Space: O(N) for visited set and region storage - """ + """Partition the entire graph into non-overlapping LEAF regions.""" logger.info(f"Partitioning graph ({len(self.graph.nodes)} nodes)") logger.debug( f"Initial state: {len(self.visited_nodes)} visited, {len(self.regions)} regions" ) - # Main partitioning loop: process each node in graph order for node_idx in range(len(self.graph.nodes)): self._build_region_from_node(node_idx) - # Log completion and coverage statistics coverage_pct = ( 100 * len(self.visited_nodes) / len(self.graph.nodes) if self.graph.nodes else 0 ) @@ -1239,9 +518,8 @@ def partition_graph(self): f"{len(self.visited_nodes)}/{len(self.graph.nodes)} nodes ({coverage_pct:.1f}%)" ) - # Log summary statistics about region sizes if self.regions: - region_sizes = [r.get_size() for r in self.regions] + region_sizes = [len(r.nodes) for r in self.regions] avg_size = sum(region_sizes) / len(region_sizes) min_size = min(region_sizes) max_size = max(region_sizes) @@ -1258,103 +536,6 @@ class TopDownRegionBuilder(RegionSearchBase): 1. Identifying and merging converged sub-patterns 2. Splitting long sequences into optimal sub-regions 3. 
Creating a hierarchical COMPOSITE region structure - - **Core Strategy:** - Starting with a flat LEAF region, creates a hierarchy by: - - **Step 1: Merge Converged Regions** - - Identifies divergent nodes within the region - - Finds their convergence points - - Groups divergence+branches+convergence into sub-regions - - Leaves remaining nodes for sequence splitting - - **Step 2: Split Sequence Regions** - - Takes ungrouped nodes (not part of converged patterns) - - Splits into individual node regions initially - - Merges adjacent nodes if they form producer-consumer chains - - Avoids merging boundary operations (Conv, Gemm, etc.) - - Limits region size to prevent overly large groups - - **Step 3: Create Composite** - - Wraps all sub-regions into a single COMPOSITE region - - Computes hierarchical input/output boundaries - - Returns refined region with better internal structure - - **Merging Criteria for Sequences:** - Two adjacent sequence regions can merge if ALL of: - - Producer region's outputs go to exactly one region (simple producer→consumer chain) - - Neither region is too large (< maximum_sequence_region_size nodes each) - - Consumer node is not a boundary operation (Conv, Gemm, etc.) - - Regions are adjacent in data flow (no gaps) - - **Boundary Operations:** - These operation types are treated as boundaries (don't merge across them): - - Conv, ConvTranspose: Convolution layers - - Gemm, MatMul: Matrix multiplications - - AveragePool, MaxPool, GlobalAveragePool, GlobalMaxPool: Pooling - - Resize: Spatial resizing - - **Example Transformation:** - ``` - Input (flat LEAF region): - [Conv, BN, ReLU, Split, Path1_A, Path1_B, Path2_A, Path2_B, Concat] - - Output (hierarchical COMPOSITE region): - COMPOSITE { - LEAF {Conv}, # Boundary op stays alone - LEAF {BN, ReLU}, # Sequence merged - LEAF {Split}, # Divergent node - LEAF {Path1_A, Path1_B, Path2_A, Path2_B, Concat}, # Converged pattern - } - ``` - - **Key Features:** - - **Hierarchical Structure:** Creates parent-child region relationships - - **Pattern-Aware:** Recognizes convergence and sequence patterns - - **Size-Bounded:** Limits region sizes for optimal granularity - - **Boundary-Aware:** Respects operation type boundaries - - **Inputs:** - - A LEAF region from RegionPartitioner (flat list of nodes) - - The graph structure - - Starting region ID for new regions - - **Output:** - - A COMPOSITE region containing LEAF child regions - - Better internal structure reflecting computation patterns - - Same total nodes, but organized hierarchically - - **Usage Pattern:** - ```python - # After partitioning - initial_region = partitioner.regions[0] - - # Refine structure - builder = TopDownRegionBuilder(graph, initial_region, next_region_id=10) - refined_region = builder.build_composite_region() - - # refined_region now has hierarchical structure - print(f"Children: {len(refined_region.get_children())}") - for child in refined_region.get_children(): - print(f" Child {child.get_id()}: {child.get_size()} nodes") - ``` - - **Performance:** - - Time: O(N + E) where N = nodes in region, E = edges between them - - Space: O(N) for temporary data structures - - Attributes: - graph: The computation graph - root: Input region to refine (becomes search space) - regions: Output list of refined regions (typically one COMPOSITE) - next_region_id: Counter for assigning unique IDs to new regions - boundary_op_types: Set of operation types treated as boundaries - maximum_sequence_region_size: Maximum number of nodes allowed in a sequence region - during 
merging. Prevents overly large regions (default: 10) - - See Also: - RegionPartitioner: Creates initial regions for refinement - CombinedRegionSearch: Orchestrates partitioning and refinement """ def __init__( @@ -1363,16 +544,16 @@ def __init__( root: Region, next_region_id: int = 0, maximum_sequence_region_size: int = 10, + tensor_users_map: dict[str, list[int]] | None = None, + forward_reachable_nodes_map: dict[int, dict[int, int]] | None = None, ): - """Initialize the refiner with a region to refine. - - Args: - graph: The ONNX graph (onnx_graphsurgeon.Graph) - root: The region to refine (typically from RegionPartitioner) - next_region_id: Starting ID for new regions created during refinement - maximum_sequence_region_size: Maximum nodes per sequence region during merging (default: 10) - """ - super().__init__(graph, root=root) + """Initialize the refiner with a region to refine.""" + super().__init__( + graph, + root=root, + tensor_users_map=tensor_users_map, + forward_reachable_nodes_map=forward_reachable_nodes_map, + ) self.regions: list[Region] = [] self.next_region_id = next_region_id self.maximum_sequence_region_size = maximum_sequence_region_size @@ -1389,23 +570,7 @@ def __init__( } def _create_leaf_region(self, node_indices: set[int]) -> Region: - """Create a new LEAF region containing specified nodes. - - Helper method to construct a properly configured LEAF region: - - Assigns unique region ID - - Sets level one deeper than root - - Adds all specified nodes - - Computes input/output tensor boundaries - - Args: - node_indices: Set of node indices to include in the region - - Returns: - New LEAF region containing the specified nodes with computed boundaries - - Side Effects: - Increments next_region_id counter - """ + """Create a new LEAF region containing specified nodes.""" region = Region( region_id=self.next_region_id, level=self.root.level + 1, region_type=RegionType.LEAF ) @@ -1416,27 +581,7 @@ def _create_leaf_region(self, node_indices: set[int]) -> Region: return region def _build_region_usage_map(self, regions: list[Region]) -> dict[str, list[Region]]: - """Build mapping from tensor names to regions that consume them. - - Similar to tensor_users_map but at the region level instead of node level. - This enables efficient traversal of region dependencies for merging decisions. - - **Purpose:** - Used during sequence splitting to identify producer-consumer chains - between regions. If a tensor is consumed by only one region, that - region might be mergeable with its producer. - - Args: - regions: List of regions to analyze - - Returns: - Dictionary mapping tensor names to lists of regions that consume them. - Tensors with len(consumers) == 1 indicate potential merge opportunities. - - Example: - >>> # Tensor "X" consumed by region 5 and region 7 - >>> usage_map["X"] == [region5, region7] - """ + """Build mapping from tensor names to regions that consume them.""" region_usage_map: dict[str, list[Region]] = {} for region in regions: for tensor_name in region.inputs: @@ -1446,75 +591,11 @@ def _build_region_usage_map(self, regions: list[Region]) -> dict[str, list[Regio return region_usage_map def _split_sequence_regions(self, root: Region) -> list[Region]: - """Split a region into smaller sub-regions by merging producer-consumer chains. - - Takes a region and creates optimal sub-regions by: - 1. Initially splitting into individual single-node regions - 2. Traversing in data flow order (following tensor dependencies) - 3. 
Merging adjacent regions that form simple producer-consumer chains - 4. Respecting boundary operations and size limits - - **Algorithm:** - ``` - 1. Create one LEAF region per node - 2. Build tensor → consuming regions map - 3. Traverse regions in data flow order (BFS from inputs): - For each region: - Check if all outputs go to single consumer region - If yes and merge criteria met: - Merge this region into consumer region - Mark this region as removed - 4. Return regions not marked as removed - ``` - - **Merge Criteria (ALL must be true):** - - All outputs of producer go to exactly one consumer (simple chain) - - Producer region size < maximum_sequence_region_size (avoid overly large regions) - - Consumer region size < maximum_sequence_region_size (avoid overly large regions) - - If consumer is single-node boundary op (Conv, etc.), don't merge - - Consumer not already removed (merged elsewhere) - - **Boundary Operations:** - Single-node regions containing these ops stay independent: - Conv, ConvTranspose, Gemm, MatMul, Pooling ops, Resize - - **Example:** - ``` - Input nodes: [Conv, BN, ReLU, Add] - - Initial: Region{Conv}, Region{BN}, Region{ReLU}, Region{Add} - - Processing: - - Conv outputs only to BN, but Conv is boundary → don't merge - - BN outputs only to ReLU, both small → merge to Region{BN, ReLU} - - Region{BN,ReLU} outputs only to Add → merge to Region{BN, ReLU, Add} - - Final: Region{Conv}, Region{BN, ReLU, Add} - ``` - - **Purpose:** - Groups simple sequential operations while keeping compute-heavy - operations (Conv, Gemm) as separate regions for optimization targeting. - - Args: - root: Region to split (contains nodes to partition into sub-regions) - - Returns: - List of LEAF regions that partition the root's nodes with better - granularity than one-node-per-region or all-in-one. - - Note: - This is the "bottom" of the top-down strategy - splits fine-grained, - then merges selectively based on data flow patterns. - """ + """Split a region into smaller sub-regions by merging producer-consumer chains.""" result_regions: list[Region] = [] removed_regions: set[int] = set() - # ===================================================================== # PHASE 1: Split into Single-Node Regions - # ===================================================================== - # Start with maximum granularity: one region per node. - # This gives us the most flexibility for selective merging. for node_idx in root.get_nodes(): region = Region( region_id=self.next_region_id, level=root.level + 1, region_type=RegionType.LEAF @@ -1524,79 +605,44 @@ def _split_sequence_regions(self, root: Region) -> list[Region]: result_regions.append(region) self.next_region_id += 1 - # Build map: tensor_name -> regions that consume it - # Enables efficient lookup of producer-consumer relationships region_usage_map = self._build_region_usage_map(result_regions) - # ===================================================================== # PHASE 2: Merge Regions in Data Flow Order - # ===================================================================== - # Traverse regions following data flow (BFS from inputs). - # At each step, check if producer can merge with consumer. - # This creates longer sequences while respecting constraints. 
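The merge test applied in the traversal below can be summarized by this hypothetical helper (`_sole_consumer` is illustrative and not part of the patch; `Region` and the usage map shape match this module):

```python
def _sole_consumer(
    producer: Region, usage_map: dict[str, list[Region]]
) -> Region | None:
    """Return the single region fed by every producer output, else None."""
    target: Region | None = None
    for tensor_name in producer.outputs:
        consumers = usage_map.get(tensor_name, [])
        if len(consumers) != 1:
            return None  # output leaves scope or branches to several regions
        if target is None:
            target = consumers[0]
        elif target is not consumers[0]:
            return None  # outputs feed different regions
    return target
```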
- - # Start from root's input tensors and traverse forward - queue = deque(root.get_inputs()) + queue = deque(root.inputs) while len(queue) > 0: tensor_name = queue.popleft() - # Skip tensors not produced by any region in our scope if tensor_name not in region_usage_map: continue - # Process each region consuming this tensor (potential merge targets) consumers = region_usage_map[tensor_name] for consumer in consumers: # Skip regions already merged into others - if consumer.get_id() in removed_regions: + if consumer.id in removed_regions: continue - - # ------------------------------------------------------------- - # Check if this consumer can merge with its downstream region - # ------------------------------------------------------------- # Merging criteria: ALL outputs go to same single region common_use_region = None can_merge = True - # Check all outputs of the consumer region for output_tensor in consumer.outputs: - # Add output to queue for continued traversal queue.append(output_tensor) - - # Check if output has consumers in our region set if output_tensor not in region_usage_map: - # Output goes outside (or nowhere) - can't merge can_merge = False break - - # Get regions consuming this output use_regions = region_usage_map[output_tensor] - - # Must go to exactly ONE region (simple chain) if len(use_regions) != 1: - # Branches to multiple regions - can't merge can_merge = False break - - # Check if all outputs go to the SAME region if common_use_region is None: - # First output: remember its consumer common_use_region = use_regions[0] elif common_use_region != use_regions[0]: - # Different outputs go to different regions - can't merge can_merge = False break - # No valid downstream region to merge with - if common_use_region is None or common_use_region.get_id() in removed_regions: + if common_use_region is None or common_use_region.id in removed_regions: can_merge = False continue - - # ------------------------------------------------------------- - # Apply Additional Constraints - # ------------------------------------------------------------- - # Constraint 1: Limit the number of boundary operations after merge nodes_after_merge = set() nodes_after_merge.update(consumer.get_nodes()) @@ -1605,110 +651,35 @@ def _split_sequence_regions(self, root: Region) -> list[Region]: boundary_op_count = sum( [1 if op in self.boundary_op_types else 0 for op in node_ops] ) - if boundary_op_count > 3: can_merge = False continue - # Constraint 2: Size limits to avoid overly large regions # Keep regions manageable for optimization passes if ( - consumer.get_size() >= self.maximum_sequence_region_size - or common_use_region.get_size() >= self.maximum_sequence_region_size + len(consumer.nodes) >= self.maximum_sequence_region_size + or len(common_use_region.nodes) >= self.maximum_sequence_region_size ): # One or both regions too large - don't merge can_merge = False continue - - # ------------------------------------------------------------- - # Perform Merge - # ------------------------------------------------------------- # All criteria met: merge consumer into its downstream region if can_merge: common_use_region.merge(consumer) - removed_regions.add(consumer.get_id()) - - # ===================================================================== - # PHASE 3: Cleanup and Finalize - # ===================================================================== + removed_regions.add(consumer.id) # Remove regions that were merged into others - result_regions = [ - region for region in result_regions if 
region.get_id() not in removed_regions - ] - + result_regions = [region for region in result_regions if region.id not in removed_regions] # Recompute boundaries for all remaining regions - # (merging may have changed input/output tensors) for region in result_regions: self.compute_region_boundaries(region) return result_regions def _merge_converged_regions(self, root: Region): - """Identify and merge convergence patterns within a region. - - Traverses the region to find divergent nodes and their convergence points, - creating sub-regions that capture divergence→branches→convergence patterns. - Nodes not part of any convergence pattern are left for sequence splitting. - - **Algorithm:** - ``` - 1. Traverse region in data flow order (BFS from inputs) - 2. For each node: - If node is divergent (branches): - Find convergence point - If convergence exists within root: - Create LEAF region with all nodes between divergence and convergence - Mark those nodes as removed (grouped) - 3. Create LEAF region for remaining ungrouped nodes - 4. Return all created regions - ``` - - **Convergence Detection:** - Uses inherited _find_converge_nodes() to identify where branches rejoin. - Only creates convergence regions if the convergence point is within - the root region being refined. - - **Example:** - ``` - Root contains: [A, B, C, D, E, F, G] - - Graph structure: - A → B (divergent) → C, D - C → E - D → E (convergence) - E → F → G - - Result: - - Region1 {B, C, D, E}: Convergence pattern - - Region2 {A, F, G}: Remaining sequence nodes - ``` - - **Use Case:** - Captures patterns like: - - ResNet skip connections (Conv branch + identity → Add) - - Inception modules (multiple parallel conv paths → Concat) - - Attention mechanisms (Q/K/V branches → attention computation) - - **Limitations:** - - Only finds convergence patterns where convergence is in root region - - Nodes can only belong to one convergence pattern (first match wins) - - Uses intersection with root nodes to ensure boundaries respected - - Args: - root: Region to analyze for convergence patterns - - Returns: - List of LEAF regions: - - Some containing convergence patterns (divergence + branches + convergence) - - One containing remaining nodes not part of any pattern - - Note: - This is the "top" of the top-down strategy - identifies high-level - patterns first, then delegates remaining nodes to sequence splitting. 
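A small diamond illustrates the intended grouping (hypothetical node names):

```python
# B diverges to C and D, which reconverge at E:
#
#   A -> B -> C -> E -> F
#         \-> D --^
#
# _merge_converged_regions groups the nodes between the divergence at B and
# the convergence at E into one LEAF region; A and F are left for
# _split_sequence_regions to handle.
```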
- """ + """Identify and merge convergence patterns within a region.""" result_regions: list[Region] = [] removed_nodes: set[int] = set() - queue = deque(root.get_inputs()) + queue = deque(root.inputs) while len(queue) > 0: tensor_name = queue.popleft() if tensor_name not in self.tensor_users_map: @@ -1729,7 +700,7 @@ def _merge_converged_regions(self, root: Region): if not self._is_node_divergent(node_idx): continue converge_node_idx, visited_nodes = self._find_converge_nodes(node_idx) - visited_nodes = visited_nodes.intersection(root.get_all_nodes_recursive()) + visited_nodes = visited_nodes.intersection(root.get_region_nodes_and_descendants()) # if no convergence found, skip if converge_node_idx is None: continue @@ -1740,7 +711,7 @@ def _merge_converged_regions(self, root: Region): removed_nodes.update(visited_nodes) continue # create a leaf region for the remaining nodes - remaining_nodes = root.get_nodes() - removed_nodes + remaining_nodes = set(root.get_nodes()) - removed_nodes if len(remaining_nodes) > 0: result_regions.append(self._create_leaf_region(remaining_nodes)) # compute region boundaries for all regions @@ -1749,85 +720,7 @@ def _merge_converged_regions(self, root: Region): return result_regions def build_composite_region(self) -> Region: - """Refine a flat region into a hierarchical COMPOSITE region. - - This is the main entry point for top-down refinement. Transforms a flat - LEAF region from RegionPartitioner into a hierarchical structure with - better internal organization. - - **Three-Stage Algorithm:** - - **Stage 1: Merge Converged Patterns** - Identifies divergence→convergence patterns and groups them: - - Finds divergent nodes where computation branches - - Locates convergence points where branches rejoin - - Creates sub-regions for complete convergence patterns - - Leaves ungrouped nodes for next stage - - **Stage 2: Split Sequence Regions** - Takes remaining (ungrouped) nodes and optimizes granularity: - - Splits into fine-grained (single-node) regions - - Merges adjacent regions forming producer-consumer chains - - Respects boundary operations (Conv, Gemm, etc.) 
- - Limits region sizes to avoid overly large groups - - **Stage 3: Create Composite Wrapper** - Wraps all refined sub-regions into hierarchy: - - Creates COMPOSITE region at same level as input root - - Adds all refined LEAF regions as children - - Computes input/output boundaries for composite - - Returns single COMPOSITE containing hierarchical structure - - **Transformation Example:** - ``` - Input (flat LEAF region from partitioner): - Region(nodes=[0,1,2,3,4,5,6,7,8]) - - After Stage 1 (converged patterns): - [Region{0,1,2}, Region{3,4,5,6,7,8}] # Found one convergence - - After Stage 2 (sequence splitting): - [Region{0,1,2}, Region{3}, Region{4,5,6}, Region{7,8}] - - After Stage 3 (composite wrapping): - COMPOSITE { - LEAF{0,1,2}, # Convergence pattern - LEAF{3}, # Boundary op - LEAF{4,5,6}, # Merged sequence - LEAF{7,8} # Merged sequence - } - ``` - - **Benefits:** - - **Better Granularity:** Not too coarse, not too fine - - **Pattern Recognition:** Convergence patterns kept together - - **Optimization-Friendly:** Boundary ops isolated for targeting - - **Hierarchical:** Enables recursive optimization strategies - - **Invariants Maintained:** - - Total node count unchanged (reorganization only) - - All nodes assigned to exactly one LEAF region - - LEAF regions don't overlap - - Parent-child relationships properly formed - - **Output Format:** - Always returns a single region: - - If input had >1 nodes: COMPOSITE region with LEAF children - - If input had 1 node: That single LEAF region unchanged - - Returns: - COMPOSITE region containing hierarchically organized LEAF sub-regions. - The composite represents the same nodes as input root but with - better internal structure reflecting computation patterns. - - Example: - >>> builder = TopDownRegionBuilder(graph, flat_region, next_id=10) - >>> refined = builder.build_composite_region() - >>> print(f"Type: {refined.get_type()}") # COMPOSITE - >>> print(f"Children: {len(refined.get_children())}") # 4-10 typically - >>> for child in refined.get_children(): - ... 
print(f" {child.get_id()}: {child.get_size()} nodes") - """ + """Refine a flat region into a hierarchical COMPOSITE region.""" # merge converged regions into composite regions self.regions = self._merge_converged_regions(self.root) # split sequence regions into smaller regions @@ -1860,77 +753,8 @@ class CombinedRegionSearch(RegionSearchBase): This class implements a sophisticated region discovery algorithm that combines two complementary strategies to create well-formed, hierarchical regions from an ONNX - computation graph: - - **Phase 1: Bottom-Up Partitioning (RegionPartitioner)** - - Traverses the graph from inputs to outputs - - Identifies divergent nodes (nodes with outputs consumed by multiple branches) - - Finds convergence points where divergent branches rejoin - - Creates initial LEAF regions based on divergence/convergence patterns - - Groups linear sequences of non-divergent nodes together - - **Phase 2: Top-Down Refinement (TopDownRegionBuilder)** - - Takes each region from Phase 1 as input - - Identifies and merges converged sub-regions within each region - - Splits long sequences into smaller, more manageable regions - - Creates COMPOSITE regions with hierarchical structure - - Ensures region boundaries align with natural computation patterns - - **Key Features:** - - **Comprehensive Coverage:** Visits all nodes in the graph - - **Hierarchical Structure:** Creates multi-level region hierarchies - - **Pattern Recognition:** Identifies divergence/convergence patterns - - **Boundary Computation:** Automatically computes input/output tensors for each region - - **Quality Metrics:** Provides coverage and node count statistics - - **Region Types Created:** - - LEAF regions: Basic building blocks containing graph nodes - - COMPOSITE regions: Higher-level regions containing child regions - - **Use Cases:** - - Graph partitioning for distributed execution - - Identifying optimization boundaries for quantization/pruning - - Creating sub-graphs for incremental processing - - Analyzing graph structure and dependencies - - **Algorithm Overview:** - 1. Initialize RegionPartitioner for bottom-up search - 2. Partition graph into initial LEAF regions - 3. For each initial region: - a. Merge converged sub-regions - b. Split long sequences into smaller regions - c. Create COMPOSITE region hierarchy - 4. Compute final region boundaries - - **Output:** - A list of COMPOSITE regions that collectively cover the entire graph, - each containing a hierarchical structure of child regions. - - **Example:** - >>> search = CombinedRegionSearch(graph) - >>> regions = search.search_regions() - >>> print(f"Created {len(regions)} top-level regions") - >>> for region in regions: - ... 
print(f"Region {region.get_id()}: {region.get_size()} nodes") - - **Performance Considerations:** - - Complexity depends on graph structure (divergence/convergence patterns) - - Pre-computes forward-reachable nodes for efficient convergence detection - - Uses BFS for systematic graph traversal - - **Validation:** - - Logs warnings if node counts change during refinement - - Verifies coverage of all nodes in the graph - - Ensures no duplicate nodes across regions - - Attributes: - graph: The ONNX graph to partition (onnx_graphsurgeon.Graph) - regions: List of top-level COMPOSITE regions created by the search - region_partitioner: Internal RegionPartitioner instance - root: Root region containing all graph nodes (inherited from RegionSearchBase) - tensor_users_map: Mapping from tensor names to consuming node indices - forward_reachable_nodes_map: Pre-computed forward reachability information - maximum_sequence_region_size: Maximum nodes per sequence region during merging + computation graph. + """ def __init__( @@ -1939,410 +763,63 @@ def __init__( maximum_sequence_region_size: int = 10, minimum_topdown_search_size: int = 10, ): - """Initialize CombinedRegionSearch for a given ONNX graph. - - Sets up the necessary data structures for two-phase region search: - - Initializes base class with graph and builds root region - - Creates empty regions list for storing results - - Initializes RegionPartitioner for Phase 1 bottom-up search - - Pre-computes tensor users map and forward reachability information - - Args: - graph: The ONNX graph to partition (onnx_graphsurgeon.Graph). - Must be a valid, connected computation graph. - maximum_sequence_region_size: Maximum nodes per sequence region during merging - in Phase 2 refinement (default: 10) - minimum_topdown_search_size: Minimum nodes per region to search during top-down refinement (default: 10) - - Note: - Initialization performs pre-computation that scales with graph size. - For very large graphs, this may take significant time. - - Example: - >>> import onnx_graphsurgeon as gs - >>> import onnx - >>> model = onnx.load("model.onnx") - >>> graph = gs.import_onnx(model) - >>> search = CombinedRegionSearch(graph, maximum_sequence_region_size=10) - """ + """Initialize CombinedRegionSearch for a given ONNX graph.""" super().__init__(graph) self.regions: list[Region] = [] - self.region_partitioner = RegionPartitioner(graph) self.minimum_topdown_search_size = minimum_topdown_search_size self.maximum_sequence_region_size = maximum_sequence_region_size def search_regions(self) -> list[Region]: - """Execute two-phase region search to partition the graph into hierarchical regions. - - This is the main entry point for the CombinedRegionSearch algorithm. It performs - a sophisticated two-phase analysis of the computation graph: - - **Phase 1: Bottom-Up Partitioning** - Uses RegionPartitioner to create initial regions by: - - Traversing graph from inputs to outputs - - Identifying divergent nodes (where computation branches) - - Finding convergence points (where branches rejoin) - - Grouping linear sequences of operations - - Creating initial LEAF regions based on these patterns - - **Phase 2: Top-Down Refinement** - For each region from Phase 1, uses TopDownRegionBuilder to: - - Identify and merge converged sub-patterns within the region - - Split long sequences into smaller, more manageable regions - - Create hierarchical COMPOSITE region structures - - Ensure optimal region granularity for optimization - - **Algorithm Steps:** - 1. 
Initialize RegionPartitioner with the graph - 2. Partition graph into initial regions (Phase 1) - 3. Log partitioning statistics (coverage, region count) - 4. For each initial region: - a. Create TopDownRegionBuilder for refinement - b. Share tensor users map for efficient lookups - c. Build composite region hierarchy (Phase 2) - d. Validate node count consistency - e. Recompute region boundaries - 5. Return final list of refined regions - - **Output Structure:** - Each returned region is typically a COMPOSITE region containing: - - LEAF child regions with actual graph nodes - - Computed input/output tensor boundaries - - Hierarchical structure reflecting computation patterns - - **Quality Metrics Logged:** - - Total regions found: Number of top-level regions created - - Total nodes visited: How many graph nodes were processed - - Coverage percentage: What fraction of the graph was partitioned - - **Validation:** - - Warns if node counts change during refinement (potential bug) - - Ensures all nodes are accounted for - - Verifies region boundary consistency - - Returns: - List of Region objects representing the partitioned graph. - Each region is a COMPOSITE region with a hierarchical structure - of child regions. The regions collectively cover all nodes in - the graph without overlap. - - Raises: - May propagate exceptions from RegionPartitioner or TopDownRegionBuilder - if graph structure is invalid or contains unsupported patterns. - - Example: - >>> search = CombinedRegionSearch(graph) - >>> regions = search.search_regions() - >>> print(f"Graph partitioned into {len(regions)} regions") - >>> # Analyze results - >>> total_nodes = sum(r.get_all_nodes_recursive_count() for r in regions) - >>> print(f"Total nodes in all regions: {total_nodes}") - >>> # Print hierarchical structure - >>> for region in regions: - ... search.print_tree(region) - - Note: - This method modifies self.regions and returns it. Calling this - method multiple times will overwrite previous results. - - See Also: - RegionPartitioner: Phase 1 bottom-up partitioning - TopDownRegionBuilder: Phase 2 top-down refinement - print_tree: Visualize the resulting region hierarchy - """ - # ===================================================================== - # PHASE 1: Bottom-Up Partitioning - # ===================================================================== - # Create a fresh RegionPartitioner instance for this search. - # This performs initial graph analysis including: - # - Building tensor-to-users mapping for tracking data flow - # - Computing forward reachability for convergence detection + """Execute two-phase region search to partition the graph into hierarchical regions.""" logger.info("Phase 1: Bottom-up partitioning") logger.debug("Initializing RegionPartitioner") region_partitioner = RegionPartitioner(self.graph) # Execute the bottom-up partitioning algorithm. - # This traverses the graph and creates initial LEAF regions based on: - # - Divergence/convergence patterns (where computation branches/rejoins) - # - Linear sequences of non-divergent nodes - # - Graph structure and operation types self.regions = region_partitioner.partition_graph() - # ===================================================================== - # Log Phase 1 Results - # ===================================================================== - # Report statistics about the initial partitioning to help understand - # graph structure and verify complete coverage. 
coverage_pct = ( - 100 * len(self.region_partitioner.visited_nodes) / len(self.graph.nodes) + 100 * len(region_partitioner.visited_nodes) / len(self.graph.nodes) if self.graph.nodes else 0 ) logger.info( f"Phase 1 complete: {len(self.regions)} regions, " - f"{len(self.region_partitioner.visited_nodes)}/{len(self.graph.nodes)} nodes ({coverage_pct:.1f}%)" + f"{len(region_partitioner.visited_nodes)}/{len(self.graph.nodes)} nodes ({coverage_pct:.1f}%)" ) logger.debug("Proceeding to Phase 2: Top-down refinement") - # ===================================================================== - # PHASE 2: Top-Down Refinement - # ===================================================================== - # Track the next available region ID to ensure unique IDs across all regions. - # This is important because we'll be creating new regions during refinement. logger.info("Phase 2: Top-down refinement") next_region_id = region_partitioner.current_region_id - # Process each initial region to refine its structure. - # Each region from Phase 1 becomes a root for hierarchical refinement. refined_count = 0 - skipped_count = 0 - for idx in range(len(self.regions)): - total_nodes = len(self.regions[idx].get_all_nodes_recursive()) - if total_nodes < self.minimum_topdown_search_size: - logger.debug(f"Skipping region {idx}: {total_nodes} nodes (below minimum)") - skipped_count += 1 + for idx, region in enumerate(self.regions): + node_count = len(region.get_region_nodes_and_descendants()) + if node_count < self.minimum_topdown_search_size: + logger.debug(f"Skipping region {idx}: {node_count} nodes (below minimum)") continue - # Create a TopDownRegionBuilder for this specific region. - # This builder will analyze the region and create a hierarchical - # structure of child regions based on internal patterns. - logger.debug(f"Refining region {idx}: {total_nodes} nodes") + logger.debug(f"Refining region {idx}: {node_count} nodes") region_builder = TopDownRegionBuilder( self.graph, - self.regions[idx], + region, next_region_id=next_region_id, maximum_sequence_region_size=self.maximum_sequence_region_size, + tensor_users_map=region_partitioner.tensor_users_map, + forward_reachable_nodes_map=region_partitioner.forward_reachable_nodes_map, ) - # Share the tensor users map from Phase 1 to avoid recomputation. - # This map is expensive to build and is shared across all refinements. - region_builder.tensor_users_map = region_partitioner.tensor_users_map - - # Track node count for validation. - # The refinement should reorganize nodes into hierarchies without - # losing or duplicating any nodes. - node_count_before = len(self.regions[idx].get_all_nodes_recursive()) - - # Execute top-down refinement on this region. - # This creates a COMPOSITE region with hierarchical structure: - # 1. Merges converged sub-regions (nodes between divergence/convergence) - # 2. Splits long sequences into smaller regions - # 3. Creates appropriate parent-child relationships self.regions[idx] = region_builder.build_composite_region() - - # Validate that refinement preserved all nodes. - # A mismatch indicates a bug in the refinement logic.
- node_count_after = len(self.regions[idx].get_all_nodes_recursive()) - if node_count_before != node_count_after: + node_count_after = len(self.regions[idx].get_region_nodes_and_descendants()) + if node_count != node_count_after: logger.warning( - f"Node count mismatch in region {idx}: {node_count_before} → {node_count_after}" + f"Node count mismatch in region {idx}: {node_count} → {node_count_after}" ) - # Recompute region boundaries after refinement. - # The hierarchical structure may have changed the input/output - # tensors at the top level of this region. region_partitioner.compute_region_boundaries(self.regions[idx]) - - # Update next_region_id for the next iteration. - # Each builder may have created new regions with new IDs. next_region_id = region_builder.next_region_id refined_count += 1 - logger.info(f"Phase 2 complete: refined {refined_count} regions, skipped {skipped_count}") + logger.info(f"Phase 2 complete: refined {refined_count}/{len(self.regions)} regions") - # Return the final refined regions return self.regions - - -# ============================================================================= -# Region Search Inspection Tool -# ============================================================================= - - -def inspect_region_search( - onnx_path: str, - max_sequence_size: int = 10, - include_all_regions: bool = False, -) -> list[Region]: - """Inspect region search results for an ONNX model. - - This function loads an ONNX model, runs CombinedRegionSearch (which performs - both bottom-up partitioning and top-down refinement internally), and prints - detailed information about the discovered regions including their hierarchical - structure. - - **What it does:** - 1. Loads ONNX model and converts to GraphSurgeon format - 2. Creates CombinedRegionSearch instance with specified parameters - 3. Runs two-phase search (partitioning + refinement) via search() - 4. Displays detailed region structure and statistics - 5. Returns the final list of refined regions - - **Output Sections:** - - Initialization: Shows search parameters - - Two-Phase Search: Runs automatically via CombinedRegionSearch.search() - - Detailed Structure: Shows each region's hierarchy and properties - - Summary Statistics: Shows region counts and node coverage - - Args: - onnx_path: Path to the ONNX model file - max_sequence_size: Maximum size for sequence regions during refinement (default: 10) - include_all_regions: Include all regions, even those without major quantizable - operations (Conv, MatMul, etc.). 
Default: False (skips such regions) - - Returns: - List of discovered and refined regions (LEAF and COMPOSITE) - - Example: - >>> # Inspect model with default settings - >>> regions = inspect_region_search("model.onnx") - >>> print(f"Found {len(regions)} regions") - >>> - >>> # Custom sequence size - >>> regions = inspect_region_search("model.onnx", max_sequence_size=20) - >>> - >>> # Include all regions - >>> regions = inspect_region_search("model.onnx", include_all_regions=True) - """ - # Load ONNX model - logger.info(f"Loading model: {onnx_path}") - onnx_model = onnx.load(onnx_path) - - # Convert to onnx_graphsurgeon Graph - graph = gs.import_onnx(onnx_model) - graph.cleanup().toposort() - logger.info( - f"Loaded graph: {len(graph.nodes)} nodes, {len(graph.inputs)} inputs, {len(graph.outputs)} outputs" - ) - - # Initialize CombinedRegionSearch (contains RegionPartitioner internally) - logger.debug( - f"Search parameters: max_steps={DEFAULT_MAX_STEPS}, max_sequence_size={max_sequence_size}" - ) - - combined_search = CombinedRegionSearch(graph, maximum_sequence_region_size=max_sequence_size) - - # Run complete two-phase region search - logger.info("Running region search") - regions = combined_search.search_regions() - - # Show detailed region structure - logger.info("Analyzing region structure") - all_regions = [] - for i, region in enumerate(regions): - for child in region.get_children(): - if not include_all_regions and not has_quantizable_operations(child, graph): - region.remove_child(child) - if not include_all_regions and not has_quantizable_operations(region, graph): - logger.debug(f"Filtered out region {i} (no quantizable operations)") - continue - logger.debug( - f"Region {i}: {region.get_type().value}, {len(region.get_all_nodes_recursive())} nodes, " - f"{len(region.inputs)} inputs, {len(region.outputs)} outputs" - ) - all_regions.append(region) - if region.get_type() == RegionType.COMPOSITE: - logger.debug(f" {len(region.get_children())} child regions") - all_regions.extend(region.get_children()) - combined_search.print_tree(region, indent=2) - - # Summary statistics - leaf_regions = sum(1 for r in all_regions if r.get_type() == RegionType.LEAF) - composite_regions = sum(1 for r in all_regions if r.get_type() == RegionType.COMPOSITE) - - all_nodes = set() - for region in all_regions: - all_nodes.update(region.get_all_nodes_recursive()) - total_nodes = len(all_nodes) - coverage_pct = 100 * total_nodes / len(graph.nodes) if graph.nodes else 0 - - logger.info( - f"Summary: {len(all_regions)} regions ({leaf_regions} LEAF, {composite_regions} COMPOSITE), " - f"{total_nodes}/{len(graph.nodes)} nodes ({coverage_pct:.1f}%)" - ) - - # Print histogram of region sizes - region_sizes = [ - len(r.get_all_nodes_recursive()) for r in all_regions if r.get_type() == RegionType.LEAF - ] - - if region_sizes: - min_size = min(region_sizes) - max_size = max(region_sizes) - avg_size = sum(region_sizes) / len(region_sizes) - - logger.info(f"LEAF region sizes: min={min_size}, max={max_size}, avg={avg_size:.1f}") - - # Create histogram bins - size_counts = Counter(region_sizes) - - # Display histogram - logger.debug("Size distribution:") - for size in sorted(size_counts.keys()): - count = size_counts[size] - bar = "█" * min(count, 50) # Cap bar length at 50 - logger.debug(f" {size:4d} nodes: {bar} ({count} regions)") - - return regions - - -def main(): - """Command-line entry point for region search inspection.""" - parser = argparse.ArgumentParser( - 
prog="modelopt.onnx.quantization.autotune.region_search", - description="Inspect region search results for ONNX models", - formatter_class=argparse.RawDescriptionHelpFormatter, - epilog=""" -Examples: - # Basic inspection - python -m modelopt.onnx.quantization.autotune.region_search --model model.onnx - - # Verbose mode for debug logging - python -m modelopt.onnx.quantization.autotune.region_search \\ - --model model.onnx --verbose - - # Custom maximum sequence size - python -m modelopt.onnx.quantization.autotune.region_search \\ - --model model.onnx --max-sequence-size 20 - """, - ) - - parser.add_argument("--model", "-m", type=str, required=True, help="Path to ONNX model file") - parser.add_argument( - "--max-sequence-size", - type=int, - default=10, - help="Maximum size for sequence regions during refinement (default: 10)", - ) - parser.add_argument( - "--include-all-regions", - action="store_true", - help="Include all regions, even those without major quantizable operations. " - "Default: False (skips such regions)", - ) - parser.add_argument("--verbose", "-v", action="store_true", help="Enable verbose debug logging") - - args = parser.parse_args() - - # Configure logging - log_level = logging.DEBUG if args.verbose else logging.INFO - logging.basicConfig(level=log_level, format="%(asctime)s - %(levelname)s - %(message)s") - logger.setLevel(log_level) - - # Run inspection - try: - regions = inspect_region_search( - onnx_path=args.model, - max_sequence_size=args.max_sequence_size, - include_all_regions=args.include_all_regions, - ) - logger.info(f"✓ Inspection complete: {len(regions)} top-level regions discovered") - return 0 - except Exception as e: - logger.error(f"Inspection failed: {e}", exc_info=args.verbose) - return 1 - - -if __name__ == "__main__": - sys.exit(main()) diff --git a/modelopt/onnx/quantization/graph_utils.py b/modelopt/onnx/quantization/graph_utils.py index a30a113ec..77dec9441 100755 --- a/modelopt/onnx/quantization/graph_utils.py +++ b/modelopt/onnx/quantization/graph_utils.py @@ -315,7 +315,8 @@ def get_tensor_consumer_node_indices(graph: onnx.GraphProto | gs.Graph) -> dict[ for node_idx, node in enumerate(nodes): inputs = node.inputs if isinstance(node, gs.Node) else node.input for tensor in inputs: - tensor_consumer_map[tensor.name].append(node_idx) + tensor_name = tensor.name if isinstance(tensor, gs.Tensor) else tensor + tensor_consumer_map[tensor_name].append(node_idx) return tensor_consumer_map diff --git a/modelopt/onnx/quantization/qdq_utils.py b/modelopt/onnx/quantization/qdq_utils.py index f533a8ede..abda4dacd 100644 --- a/modelopt/onnx/quantization/qdq_utils.py +++ b/modelopt/onnx/quantization/qdq_utils.py @@ -1040,42 +1040,25 @@ def cast_initializer_to_dtype( def get_quantized_tensors(onnx_model: onnx.ModelProto) -> set[str]: """Get the names of all quantized tensors from an ONNX model. - This function identifies all QuantizeLinear nodes in the ONNX model - and extracts the names of tensors being quantized (the first input of - each QuantizeLinear node, excluding scale and zero-point inputs). + This function identifies all DequantizeLinear nodes in the ONNX model + and extracts the names of tensors being dequantized (the first input of + each DequantizeLinear node, excluding scale and zero-point inputs). 
Args: onnx_model: ONNX model protobuf to analyze Returns: - Set of tensor names that are inputs to QuantizeLinear nodes - (i.e., the tensors being quantized) - - Example: - >>> import onnx - >>> from modelopt.onnx.quantization.qdq_utils import get_quantized_tensors - >>> - >>> # Load a quantized model - >>> model = onnx.load("quantized_model.onnx") - >>> - >>> # Get all quantized tensor names - >>> quantized_tensors = get_quantized_tensors(model) - >>> print(f"Found {len(quantized_tensors)} quantized tensors") - >>> - >>> # Use with autotuner to import insertion points - >>> from modelopt.onnx.quantization.autotune import QDQAutotuner - >>> autotuner = QDQAutotuner(new_model) - >>> autotuner.initialize() - >>> autotuner.import_insertion_points(quantized_tensors) + Set of tensor names that are inputs to DequantizeLinear nodes + (i.e., the tensors being dequantized) """ quantized_tensors = set() for node in onnx_model.graph.node: if node.op_type == "DequantizeLinear": - # First input is the tensor being quantized + # First input is the tensor being dequantized # (inputs[1] is scale, inputs[2] is zero-point) if node.input and len(node.input) > 0: quantized_tensors.add(node.input[0]) - logger.debug(f"Found {len(quantized_tensors)} quantized tensors in ONNX model") + logger.debug(f"Found {len(quantized_tensors)} dequantized tensors in ONNX model") return quantized_tensors diff --git a/tests/unit/onnx/quantization/autotune/test_region_search.py b/tests/unit/onnx/quantization/autotune/test_region_search.py index d510cc277..e63f5c4c0 100644 --- a/tests/unit/onnx/quantization/autotune/test_region_search.py +++ b/tests/unit/onnx/quantization/autotune/test_region_search.py @@ -313,7 +313,7 @@ def test_print_tree_contains_node_info(self): # Should contain node counts assert "Direct nodes:" in result - assert "Total nodes (recursive):" in result + assert "Total nodes:" in result assert "Children:" in result def test_print_tree_contains_io_info(self): @@ -330,7 +330,6 @@ def test_print_tree_contains_io_info(self): # Should contain I/O information assert "Inputs:" in result assert "Outputs:" in result - assert "tensors" in result def test_print_tree_divergent_graph(self): """Test print_tree on a divergent graph with more complex structure.""" @@ -355,10 +354,10 @@ def test_print_tree_max_nodes_to_show(self): # Test with different max_nodes_to_show values output1 = io.StringIO() - search.print_tree(max_nodes_to_show=1, file=output1) + search.print_tree(max_items=1, file=output1) output2 = io.StringIO() - search.print_tree(max_nodes_to_show=10, file=output2) + search.print_tree(max_items=10, file=output2) # Both should produce output assert len(output1.getvalue()) > 0 From dd41ca2733cabb0713dc9efc39dfce99538fc77d Mon Sep 17 00:00:00 2001 From: Will Guo Date: Tue, 27 Jan 2026 11:14:25 +0000 Subject: [PATCH 3/5] resolve comment Signed-off-by: Will Guo --- modelopt/onnx/quantization/autotune/common.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/modelopt/onnx/quantization/autotune/common.py b/modelopt/onnx/quantization/autotune/common.py index 5418e2a58..b55ae4239 100644 --- a/modelopt/onnx/quantization/autotune/common.py +++ b/modelopt/onnx/quantization/autotune/common.py @@ -558,7 +558,7 @@ def add_pattern_schemes(self, pattern_schemes: PatternSchemes) -> None: for scheme in sorted_schemes: # Check if this scheme is too similar to any already-filtered scheme too_similar = False - scheme_to_replace = [] + schemes_to_replace = [] for existing_scheme in filtered_schemes: 
distance = scheme.distance(existing_scheme) if distance < self.minimum_distance: @@ -566,13 +566,13 @@ def add_pattern_schemes(self, pattern_schemes: PatternSchemes) -> None: too_similar = True if scheme.latency_ms < existing_scheme.latency_ms: # New scheme is better, mark existing for replacement - scheme_to_replace.append(existing_scheme) + schemes_to_replace.append(existing_scheme) if not too_similar: filtered_schemes.append(scheme) - elif scheme_to_replace: - for scheme_to_replace in scheme_to_replace: - filtered_schemes.remove(scheme_to_replace) + elif schemes_to_replace: + for scheme in schemes_to_replace: + filtered_schemes.remove(scheme) filtered_schemes.append(scheme) sorted_schemes = filtered_schemes From 6412284de6ed1fe68be010165593e2f051424e9e Mon Sep 17 00:00:00 2001 From: Will Guo Date: Wed, 28 Jan 2026 06:30:56 +0000 Subject: [PATCH 4/5] rename scheme-> to scheme_to_replace Signed-off-by: Will Guo --- modelopt/onnx/quantization/autotune/common.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/modelopt/onnx/quantization/autotune/common.py b/modelopt/onnx/quantization/autotune/common.py index b55ae4239..1eca37a0c 100644 --- a/modelopt/onnx/quantization/autotune/common.py +++ b/modelopt/onnx/quantization/autotune/common.py @@ -571,8 +571,8 @@ def add_pattern_schemes(self, pattern_schemes: PatternSchemes) -> None: if not too_similar: filtered_schemes.append(scheme) elif schemes_to_replace: - for scheme in schemes_to_replace: - filtered_schemes.remove(scheme) + for scheme_to_replace in schemes_to_replace: + filtered_schemes.remove(scheme_to_replace) filtered_schemes.append(scheme) sorted_schemes = filtered_schemes From 4afccc49d7800fa9551e72f75d50e97925a87898 Mon Sep 17 00:00:00 2001 From: Will Guo Date: Fri, 30 Jan 2026 05:33:10 +0000 Subject: [PATCH 5/5] pick fixes from PR#701 Signed-off-by: Will Guo --- modelopt/onnx/op_types.py | 64 +++ .../onnx/quantization/autotune/__init__.py | 4 +- modelopt/onnx/quantization/autotune/common.py | 83 +-- .../quantization/autotune/insertion_points.py | 531 ++++++++++++++++++ .../quantization/autotune/region_pattern.py | 4 +- .../quantization/autotune/region_search.py | 14 +- .../autotune/test_region_pattern.py | 2 +- .../autotune/test_region_search.py | 5 +- 8 files changed, 650 insertions(+), 57 deletions(-) create mode 100644 modelopt/onnx/quantization/autotune/insertion_points.py diff --git a/modelopt/onnx/op_types.py b/modelopt/onnx/op_types.py index cd32ba17a..7e11d25e6 100644 --- a/modelopt/onnx/op_types.py +++ b/modelopt/onnx/op_types.py @@ -305,6 +305,70 @@ def is_data_dependent_shape_op(op_type: str): ] +def get_bool_ops(): + """Returns set of bool operations.""" + return { + "Not", + "And", + "Or", + "Xor", + } + + +def get_bitwise_ops(): + """Returns set of bitwise operations.""" + return { + "BitwiseAnd", + "BitwiseOr", + "BitwiseXor", + "BitShift", + } + + +def get_value_check_ops(): + """Returns set of value checking operations.""" + return { + "IsNaN", + "IsInf", + "Sign", + "Abs", + } + + +def get_comparison_ops(): + """Returns set of comparison operations.""" + return { + "Equal", + "Greater", + "GreaterOrEqual", + "Less", + "LessOrEqual", + } + + +def get_conditional_ops(): + """Returns set of conditional operations.""" + return { + "Where", + } + + +def get_aggregation_ops(): + """Returns set of aggregation operations.""" + return { + "All", + "Any", + } + + +def get_set_ops(): + """Returns set of set/search operations.""" + return { + "Unique", + "NonZero", + } + + def get_symmetric_ops(): 
"""Returns set of commutative/symmetric operations where operand order doesn't matter.""" return { diff --git a/modelopt/onnx/quantization/autotune/__init__.py b/modelopt/onnx/quantization/autotune/__init__.py index 3aa63c94c..f60dc917f 100644 --- a/modelopt/onnx/quantization/autotune/__init__.py +++ b/modelopt/onnx/quantization/autotune/__init__.py @@ -34,8 +34,8 @@ ) from .insertion_points import ( ChildRegionInputInsertionPoint, + ChildRegionOutputInsertionPoint, NodeInputInsertionPoint, - RegionOutputInsertionPoint, ResolvedInsertionPoint, ) from .region_pattern import RegionPattern @@ -45,6 +45,7 @@ "AutotunerError", "AutotunerNotInitializedError", "ChildRegionInputInsertionPoint", + "ChildRegionOutputInsertionPoint", "CombinedRegionSearch", "Config", "InsertionScheme", @@ -53,7 +54,6 @@ "PatternCache", "PatternSchemes", "Region", - "RegionOutputInsertionPoint", "RegionPattern", "RegionType", "ResolvedInsertionPoint", diff --git a/modelopt/onnx/quantization/autotune/common.py b/modelopt/onnx/quantization/autotune/common.py index 1eca37a0c..db7c9b373 100644 --- a/modelopt/onnx/quantization/autotune/common.py +++ b/modelopt/onnx/quantization/autotune/common.py @@ -26,8 +26,8 @@ from modelopt.onnx.logging_config import logger from modelopt.onnx.quantization.autotune.insertion_points import ( ChildRegionInputInsertionPoint, + ChildRegionOutputInsertionPoint, NodeInputInsertionPoint, - RegionOutputInsertionPoint, ResolvedInsertionPoint, ) @@ -88,7 +88,14 @@ def __init__(self, region_id: int, level: int, region_type: RegionType): self.metadata: dict[str, str] = {} def get_children(self, *, sort: bool = False) -> list["Region"]: - """Get all child regions.""" + """Get all child regions. If sort is True, sort the children by level and size. + + Args: + sort: Whether to sort the children by level and size + + Returns: + List of child regions + """ if sort: return sorted( self.children, key=lambda r: (-r.level, r.get_size_of_region_and_descendants()) @@ -143,14 +150,6 @@ def is_descendant_of(self, potential_ancestor: "Region") -> bool: current = current.parent return False - def add_node(self, node_index: int) -> None: - """Add a node index to this region.""" - self.nodes.add(node_index) - - def add_nodes(self, node_indices: list[int]) -> None: - """Add multiple node indices to this region.""" - self.nodes.update(node_indices) - def get_nodes(self, *, sort: bool = False) -> list[int]: """Get direct node indices in this region only.""" if sort: @@ -179,16 +178,6 @@ def contains_node_within_region_and_descendants(self, node_index: int) -> bool: """Check if region contains a node recursively.""" return node_index in self.get_region_nodes_and_descendants() - def add_input(self, tensor_name: str) -> None: - """Add an input tensor name.""" - if tensor_name not in self.inputs: - self.inputs.append(tensor_name) - - def add_output(self, tensor_name: str) -> None: - """Add an output tensor name.""" - if tensor_name not in self.outputs: - self.outputs.append(tensor_name) - def get_size_of_region_and_descendants(self, _visited: set[int] | None = None) -> int: """Get total node count recursively including all descendants.""" if _visited is None: @@ -228,34 +217,41 @@ def compute_structural_signature(self, graph: gs.Graph) -> str: node operations, and hierarchical structure. Regions with identical signatures can share Q/DQ insertion schemes. 
- The signature captures: - - Node operation types and key parameters - - Hierarchical structure (child regions) - - Deterministic ordering (sorted for consistency) - Args: graph: The ONNX graph containing the region's nodes Returns: Signature string (e.g., "Conv->BatchNorm->Relu" or "COMPOSITE(...)") """ - raise NotImplementedError("Not implemented") + from modelopt.onnx.quantization.autotune.region_pattern import RegionPattern + + return RegionPattern.from_region(self, graph).signature @dataclass class InsertionScheme: - """Q/DQ insertion specification applied to all regions matching a pattern.""" + """Complete Q/DQ insertion specification for a region pattern. + + An InsertionScheme defines a complete Q/DQ configuration for a pattern, + combining both node-level and region-level insertion points. The scheme + is applied to all regions matching the pattern. + """ node_inputs: list[NodeInputInsertionPoint] = field(default_factory=list) child_region_inputs: list[ChildRegionInputInsertionPoint] = field(default_factory=list) - region_outputs: list[RegionOutputInsertionPoint] = field(default_factory=list) + region_outputs: list[ChildRegionOutputInsertionPoint] = field(default_factory=list) latency_ms: float = float("inf") error: bool = False profile_timestamp: str | None = None @property def hash(self) -> str: - """Compute deterministic hash for scheme identity.""" + """Compute deterministic hash for scheme identity. + + The hash uniquely identifies this scheme configuration based on its + insertion points. Two schemes with identical insertion points produce + the same hash, regardless of their measured latencies. + """ sorted_nodes = sorted([(pt.node_index, pt.input_index) for pt in self.node_inputs]) sorted_regions = sorted( [(pt.region_index, pt.input_index) for pt in self.child_region_inputs] @@ -275,7 +271,11 @@ def is_empty(self) -> bool: @property def is_profiled(self) -> bool: - """Check if this scheme has been profiled (measured).""" + """Check if this scheme has been profiled (measured). + + A scheme is considered profiled if it has been measured (has non-infinite latency) + or has encountered an error during measurement. + """ return self.error or self.latency_ms != float("inf") def to_dict(self) -> dict[str, Any]: @@ -306,13 +306,24 @@ def from_dict(cls, data: dict[str, Any]) -> "InsertionScheme": for pt in data.get("child_region_inputs", []) ] scheme.region_outputs = [ - RegionOutputInsertionPoint.from_dict(pt) for pt in data.get("region_outputs", []) + ChildRegionOutputInsertionPoint.from_dict(pt) for pt in data.get("region_outputs", []) ] return scheme def distance(self, other: "InsertionScheme") -> int: - """Compute edit distance between this scheme and another scheme.""" + """Compute edit distance between this scheme and another scheme. + + The edit distance is the minimum number of add/remove operations needed + to transform this scheme into the other scheme. This is computed as the + symmetric difference between the insertion point sets. + + Args: + other: InsertionScheme to compare against + + Returns: + Total edit distance (number of add + remove operations) + """ return ( len(set(self.node_inputs).symmetric_difference(other.node_inputs)) + len(set(self.child_region_inputs).symmetric_difference(other.child_region_inputs)) @@ -403,14 +414,6 @@ def has_schemes(self) -> bool: """Check if any schemes have been added.""" return len(self.schemes) > 0 - def add_scheme(self, scheme: InsertionScheme) -> None: - """Add a scheme to the collection. 
- - Args: - scheme: InsertionScheme to add - """ - self.schemes.append(scheme) - def get_measured_schemes(self) -> list[InsertionScheme]: """Get schemes that have been measured (finite latency). diff --git a/modelopt/onnx/quantization/autotune/insertion_points.py b/modelopt/onnx/quantization/autotune/insertion_points.py new file mode 100644 index 000000000..dd01848dd --- /dev/null +++ b/modelopt/onnx/quantization/autotune/insertion_points.py @@ -0,0 +1,531 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Q/DQ insertion point management for ONNX quantization autotune. + +This module provides data structures and utilities for managing Quantization/Dequantization (Q/DQ) +insertion points in ONNX computational graphs during autotune optimization. It enables pattern-based +Q/DQ insertion that can be reused across multiple matching regions in a model. +""" + +from abc import ABC, abstractmethod +from dataclasses import asdict, dataclass +from typing import TYPE_CHECKING, Any + +import numpy as np +import onnx_graphsurgeon as gs + +if TYPE_CHECKING: + from modelopt.onnx.quantization.autotune.common import Region + +from modelopt.onnx.op_types import ( + get_aggregation_ops, + get_bitwise_ops, + get_bool_ops, + get_comparison_ops, + get_conditional_ops, + get_copy_ops, + get_set_ops, + get_value_check_ops, + is_fusible_reduction_op, +) +from modelopt.onnx.quantization.graph_utils import get_tensor_consumer_node_indices + + +class InsertionPoint(ABC): + """Abstract base class for pattern-relative Q/DQ insertion points.""" + + @abstractmethod + def to_dict(self) -> dict[str, Any]: + """Convert to dictionary for serialization.""" + ... + + @classmethod + @abstractmethod + def from_dict(cls, data: dict[str, Any]) -> "InsertionPoint": + """Create from dictionary.""" + ... + + @abstractmethod + def resolve(self, region: "Region", graph: gs.Graph) -> set["ResolvedInsertionPoint"]: + """Resolve pattern-relative insertion point to actual tensor names.""" + ... + + @staticmethod + @abstractmethod + def collect_from_region(region: "Region", graph: gs.Graph) -> list["InsertionPoint"]: + """Collect all valid insertion points of this type from a region.""" + ... + + +@dataclass(frozen=True) +class ResolvedInsertionPoint: + """Resolved Q/DQ insertion point with actual tensor name and optional node context. + + After resolving pattern-relative insertion points, this class represents the + actual location where Q/DQ pairs should be inserted in the graph. It contains the + tensor name and the node index (if applicable) and input index (if applicable). + + This class is immutable (frozen) to allow safe use in sets and as dict keys. 
+ """ + + tensor_name: str + node_index: int | None = None # Absolute graph node index (or None for tensor-level insertion) + input_index: int | None = None # Input tensor index of that node (or None) + + def to_dict(self) -> dict[str, Any]: + """Convert to dictionary for serialization.""" + return asdict(self) + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "ResolvedInsertionPoint": + """Create from dictionary.""" + return cls(**data) + + +@dataclass(frozen=True) +class NodeInputInsertionPoint(InsertionPoint): + """Pattern-relative Q/DQ insertion point at a node's input (frozen/hashable). + + Specifies where to insert a Q/DQ pair within a region pattern using + pattern-relative indices rather than absolute node IDs. This enables + insertion scheme reuse across all regions matching the same pattern. + + This class is immutable (frozen) to allow safe use in sets and as dict keys. + """ + + node_index: int # Pattern-relative node index + input_index: int # Input tensor index of that node + + def to_dict(self) -> dict[str, Any]: + """Convert to dictionary for serialization.""" + return asdict(self) + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "NodeInputInsertionPoint": + """Create from dictionary.""" + return cls(node_index=data["node_index"], input_index=data["input_index"]) + + def resolve(self, region: "Region", graph: gs.Graph) -> set[ResolvedInsertionPoint]: + """Resolve a node input insertion point to actual tensor names for a matching region.""" + node_indices = region.get_nodes(sort=True) + assert self.node_index < len(node_indices), "Node index out of range" + actual_node_idx = node_indices[self.node_index] + node = graph.nodes[actual_node_idx] + assert self.input_index < len(node.inputs), "Input index out of range" + + resolved_ips = set() + # Determine which input indices to resolve (include weights for Conv/ConvTranspose) + input_indices = [self.input_index] + if node.op in ["Conv", "ConvTranspose"]: + assert self.input_index == 0, ( + "Conv/ConvTranspose inputs and weights must be quantized together" + ) + assert len(node.inputs) >= 2, "Conv/ConvTranspose should have at least 2 inputs" + input_indices.append(1) + + for idx in input_indices: + inp = node.inputs[idx] + if hasattr(inp, "name") and inp.name: + resolved_ips.add( + ResolvedInsertionPoint( + tensor_name=inp.name, node_index=actual_node_idx, input_index=idx + ) + ) + return resolved_ips + + @staticmethod + def collect_from_region(region: "Region", graph: gs.Graph) -> list["NodeInputInsertionPoint"]: + """Collect all valid node input insertion points from a region.""" + node_indices = region.get_nodes(sort=True) + insertion_points = [] + for local_idx, node_idx in enumerate(node_indices): + node = graph.nodes[node_idx] + for input_idx, inp in enumerate(node.inputs): + name = getattr(inp, "name", None) + if not name or skip_invalid_insertion_points(graph, name, node): + continue + insertion_points.append( + NodeInputInsertionPoint(node_index=local_idx, input_index=input_idx) + ) + return insertion_points + + +@dataclass(frozen=True) +class ChildRegionInputInsertionPoint(InsertionPoint): + """Pattern-relative Q/DQ insertion point at a child region's input boundary (frozen/hashable). + + Specifies where to insert Q/DQ pairs at the input boundaries of child regions + within COMPOSITE regions. This allows parent regions to control quantization + at child boundaries, potentially overriding or complementing child region + optimizations. 
+ + Only applies to COMPOSITE regions; LEAF regions have no children. + + This class is immutable (frozen) to allow safe use in sets and as dict keys. + """ + + # Pattern-relative child region index + region_index: int + # Input tensor index of that child region + input_index: int + + def to_dict(self) -> dict[str, Any]: + """Convert to dictionary for serialization.""" + return asdict(self) + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "ChildRegionInputInsertionPoint": + """Create from dictionary.""" + return cls(**data) + + def resolve(self, region: "Region", graph: gs.Graph) -> set[ResolvedInsertionPoint]: + """Resolve a child region input insertion point to actual tensor names.""" + from modelopt.onnx.quantization.autotune.common import RegionType + + if region.type == RegionType.LEAF: + return set() + + children_regions = region.get_children(sort=True) + assert self.region_index < len(children_regions), "Child region index out of range" + child_region = children_regions[self.region_index] + assert self.input_index < len(child_region.inputs), "Input index out of range" + tensor_name = child_region.inputs[self.input_index] + return resolve_region_io_insertion_points(child_region, graph, tensor_name) + + @staticmethod + def collect_from_region( + region: "Region", graph: gs.Graph + ) -> list["ChildRegionInputInsertionPoint"]: + """Collect all valid child region input insertion points from a region.""" + from modelopt.onnx.quantization.autotune.common import RegionType + + if region.type == RegionType.LEAF: + return [] + + insertion_points = [] + for local_idx, child_region in enumerate(region.get_children(sort=True)): + for input_idx, inp in enumerate(child_region.inputs): + if skip_invalid_insertion_points(graph, inp, child_region): + continue + insertion_points.append( + ChildRegionInputInsertionPoint(region_index=local_idx, input_index=input_idx) + ) + return insertion_points + + +@dataclass(frozen=True) +class ChildRegionOutputInsertionPoint(InsertionPoint): + """Pattern-relative Q/DQ insertion point at a child region or node output (frozen/hashable). + + Specifies where to insert Q/DQ pairs at output boundaries. This can be either: + 1. Output from a child region (in COMPOSITE regions) + 2. Output from a node within the region + + This class is immutable (frozen) to allow safe use in sets and as dict keys. 
+ """ + + region_index: int | None # Pattern-relative child region index (or None) + node_index: int | None # Pattern-relative node index (or None) + output_index: int # Output tensor index + + def to_dict(self) -> dict[str, Any]: + """Convert to dictionary for serialization.""" + return asdict(self) + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "ChildRegionOutputInsertionPoint": + """Create from dictionary.""" + return cls(**data) + + def resolve(self, region: "Region", graph: gs.Graph) -> set[ResolvedInsertionPoint]: + """Resolve a region output insertion point to actual tensor names.""" + if self.region_index is not None: + children_regions = region.get_children(sort=True) + assert self.region_index < len(children_regions), "Region index out of range" + child_region = children_regions[self.region_index] + assert self.output_index < len(child_region.outputs), "Output index out of range" + tensor_name = child_region.outputs[self.output_index] + return resolve_region_io_insertion_points(child_region, graph, tensor_name) + + if self.node_index is not None: + node_indices = region.get_nodes(sort=True) + assert self.node_index < len(node_indices), "Node index out of range" + node = graph.nodes[node_indices[self.node_index]] + assert self.output_index < len(node.outputs), "Output index out of range" + tensor = node.outputs[self.output_index] + assert hasattr(tensor, "name") and tensor.name, "Tensor name is required" + return resolve_region_io_insertion_points(None, graph, tensor.name) + + return set() + + @staticmethod + def collect_from_region( + region: "Region", graph: gs.Graph + ) -> list["ChildRegionOutputInsertionPoint"]: + """Collect all valid region output insertion points from a region.""" + from modelopt.onnx.quantization.autotune.common import RegionType + + region_outputs_set = set(region.outputs) + insertion_points = [] + + # For COMPOSITE regions: collect child region outputs + if region.type != RegionType.LEAF: + for local_idx, child_region in enumerate(region.get_children(sort=True)): + for output_idx, out in enumerate(child_region.outputs): + if out in region_outputs_set and not skip_invalid_insertion_points( + graph, out, child_region + ): + insertion_points.append( + ChildRegionOutputInsertionPoint( + region_index=local_idx, node_index=None, output_index=output_idx + ) + ) + + # For all regions: collect node outputs + for local_idx, node_idx in enumerate(region.get_nodes(sort=True)): + node = graph.nodes[node_idx] + for output_idx, out in enumerate(node.outputs): + if not (hasattr(out, "name") and out.name): + continue + if out.name in region_outputs_set and not skip_invalid_insertion_points( + graph, out.name, node + ): + insertion_points.append( + ChildRegionOutputInsertionPoint( + region_index=None, node_index=local_idx, output_index=output_idx + ) + ) + + return insertion_points + + +def skip_invalid_insertion_points( + graph: gs.Graph, tensor_name: str, region_or_node: "Region | gs.Node" +) -> bool: + """Determine if a tensor should be skipped for Q/DQ insertion. 
+ + Filters out tensors that are not suitable for quantization based on various criteria: + - Boolean and shape operations (not quantizable) + - Fused operation patterns (Conv->BatchNorm->ReLU) + - Operation-specific non-quantizable inputs (weights, biases, BN parameters) + - Non-floating-point tensors (indices, masks) + - Small tensors (scalars, small vectors with < 8 elements) + + Args: + graph: The ONNX graph containing the nodes + tensor_name: Name of the tensor to evaluate + region_or_node: Either a Region or a Node to check for usage of this tensor + + Returns: + True if the insertion point should be skipped, False if it's valid for quantization + """ + from modelopt.onnx.quantization.autotune.common import Region + + if isinstance(region_or_node, Region): + node_indices = region_or_node.get_region_nodes_and_descendants() + nodes: list[gs.Node] = [graph.nodes[node_idx] for node_idx in node_indices] + else: + assert isinstance(region_or_node, gs.Node) + nodes = [region_or_node] + + for node in nodes: + for input_idx, inp in enumerate(node.inputs): + if hasattr(inp, "name") and inp.name == tensor_name: + # Skip weights of Conv and ConvTranspose; they must be quantized together with their inputs + if node.op in ["Conv", "ConvTranspose"] and input_idx >= 1: + return True + # Conv -> ReLU/Softmax or Conv -> BatchNormalization -> ReLU/Softmax + if node.op in ["Relu", "Softmax"]: + if len(node.inputs) == 1 and len(node.inputs[0].inputs) == 1: + producer = node.inputs[0].inputs[0] + if producer.op in ["Conv", "ConvTranspose"]: + return True + if ( + producer.op == "BatchNormalization" + and len(producer.inputs[0].inputs) == 1 + and producer.inputs[0].inputs[0].op in ["Conv", "ConvTranspose"] + ): + return True + # Conv -> BatchNormalization + if node.op == "BatchNormalization": + assert len(node.inputs) >= 1, "BatchNormalization node should have at least one input" + if len(node.inputs[0].inputs) == 1: + producer = node.inputs[0].inputs[0] + if producer.op in ["Conv", "ConvTranspose"]: + return True + # Filter 1: Skip boolean, bitwise, comparison, conditional, and other non-quantizable ops + if node.op in ( + get_bool_ops() + | get_bitwise_ops() + | get_value_check_ops() + | get_comparison_ops() + | get_conditional_ops() + | get_aggregation_ops() + | get_set_ops() + ) or is_fusible_reduction_op(node.op): + return True + # Filter 2: Skip shape/structural operations + if node.op in get_autotuner_skip_ops(): + return True + # Filter 3: Skip operation-specific non-quantizable inputs + if node.op in ["BatchNormalization", "Resize"] and input_idx >= 1: + return True + if node.op in ["Conv", "Gemm"] and input_idx >= 2: + return True + # Filter 4: Skip non-floating-point tensors (int/bool indices, masks, etc.) + if hasattr(inp, "dtype") and inp.dtype not in [ + None, + np.float32, + np.float16, + np.float64, + ]: + return True + # Filter 5: Skip small tensors (scalars, small vectors) + if hasattr(inp, "shape") and inp.shape is not None: + if all(isinstance(s, int) for s in inp.shape): + if np.prod(inp.shape) < 8: + return True + return False + + +def has_quantizable_operations(region: "Region", graph: gs.Graph) -> bool: + """Check if a region contains major quantizable operations (only checks LEAF regions).
+ + Args: + region: The region to check + graph: The ONNX graph containing the nodes + + Returns: + True if the region contains major quantizable operations, False otherwise + """ + from modelopt.onnx.quantization.autotune.common import RegionType + + if region.type != RegionType.LEAF: + return True + region_ops = {graph.nodes[idx].op for idx in region.get_nodes()} + return bool(region_ops & get_autotuner_quantizable_ops()) + + +def resolve_region_io_insertion_points( + region: "Region | None", graph: gs.Graph, tensor_name: str +) -> set[ResolvedInsertionPoint]: + """Resolve region input/output boundaries to actual Q/DQ insertion points. + + For a given tensor at a region boundary (input or output), this function + identifies all the actual node inputs where Q/DQ pairs should be inserted. + It considers both nodes within the region (if provided) and all users of + the tensor in the graph. + + Args: + region: The region to search within (or None to search entire graph) + graph: The ONNX graph containing the nodes + tensor_name: Name of the tensor at the region boundary + + Returns: + Set of ResolvedInsertionPoint objects specifying where to insert Q/DQ pairs + """ + tensor_users_map = getattr(graph, "tensor_users_map", None) or get_tensor_consumer_node_indices( + graph + ) + + node_indices: set[int] = set() + if region is not None: + node_indices.update(region.get_region_nodes_and_descendants()) + node_indices.update(tensor_users_map.get(tensor_name, [])) + + resolved = set() + for node_idx in node_indices: + node = graph.nodes[node_idx] + for input_idx, inp in enumerate(node.inputs): + if hasattr(inp, "name") and inp.name == tensor_name: + if not skip_invalid_insertion_points(graph, tensor_name, node): + resolved.add( + ResolvedInsertionPoint( + tensor_name=tensor_name, node_index=node_idx, input_index=input_idx + ) + ) + return resolved + + +def merge_resolved_insertion_points( + graph: gs.Graph, resolved_insertion_points: set[ResolvedInsertionPoint] +) -> set[ResolvedInsertionPoint]: + """Optimize insertion points by merging node-specific insertions into tensor-level insertions. + + When all consumers (users) of a tensor have Q/DQ insertion points, it's more efficient + to insert Q/DQ once at the tensor level rather than at each individual node input. + This reduces the number of Q/DQ nodes in the graph and simplifies the quantization scheme. 
+ + Args: + graph: The ONNX graph containing the nodes + resolved_insertion_points: Set of resolved insertion points to optimize + + Returns: + Optimized set of insertion points with merged tensor-level insertions where possible + """ + tensor_users_map = get_tensor_consumer_node_indices(graph) + node_ips = {ip for ip in resolved_insertion_points if ip.node_index is not None} + + results = resolved_insertion_points - node_ips + for tensor_name in {ip.tensor_name for ip in node_ips}: + all_users = set(tensor_users_map.get(tensor_name, [])) + qdq_users = {ip for ip in node_ips if ip.tensor_name == tensor_name} + if all_users == {ip.node_index for ip in qdq_users}: + results.add( + ResolvedInsertionPoint(tensor_name=tensor_name, node_index=None, input_index=None) + ) + else: + results.update(qdq_users) + return results + + +def get_autotuner_skip_ops(): + """Returns set of shape/structural operations that are not quantizable.""" + return set(get_copy_ops()) | { + # Additional indexing/scatter/reshape ops + "Compress", + "Scatter", + "ExpandDims", + "Unsqueeze", + "View", + "Pad", + # Utility ops + "Cast", + "Ceil", + "Clip", + "Identity", + "Range", + "Shape", + } + + +def get_autotuner_quantizable_ops(): + """Returns set of key operations that benefit from quantization.""" + return { + "Conv", + "ConvTranspose", + "Gemm", + "MatMul", + "AveragePool", + "MaxPool", + "GlobalAveragePool", + "GlobalMaxPool", + "Resize", + "Add", + "Sum", + "Mul", + "Relu", + } diff --git a/modelopt/onnx/quantization/autotune/region_pattern.py b/modelopt/onnx/quantization/autotune/region_pattern.py index ab87abdd6..60d3225b6 100644 --- a/modelopt/onnx/quantization/autotune/region_pattern.py +++ b/modelopt/onnx/quantization/autotune/region_pattern.py @@ -24,8 +24,8 @@ from modelopt.onnx.quantization.autotune.common import InsertionScheme, Region from modelopt.onnx.quantization.autotune.insertion_points import ( ChildRegionInputInsertionPoint, + ChildRegionOutputInsertionPoint, NodeInputInsertionPoint, - RegionOutputInsertionPoint, ResolvedInsertionPoint, ) @@ -173,7 +173,7 @@ def get_full_insertion_scheme(self, region: Region, graph: gs.Graph) -> Insertio scheme.child_region_inputs = ChildRegionInputInsertionPoint.collect_from_region( region, graph ) - scheme.region_outputs = RegionOutputInsertionPoint.collect_from_region(region, graph) + scheme.region_outputs = ChildRegionOutputInsertionPoint.collect_from_region(region, graph) return scheme diff --git a/modelopt/onnx/quantization/autotune/region_search.py b/modelopt/onnx/quantization/autotune/region_search.py index 227e9b9ed..37a9b175d 100644 --- a/modelopt/onnx/quantization/autotune/region_search.py +++ b/modelopt/onnx/quantization/autotune/region_search.py @@ -58,12 +58,8 @@ def __init__( def _build_root_region(self) -> Region: """Create a root region containing all nodes in the graph.""" root = Region(region_id=0, level=0, region_type=RegionType.ROOT) - for node_idx in range(len(self.graph.nodes)): - root.add_node(node_idx) - for tensor_name in root.inputs: - root.add_input(tensor_name) - for tensor_name in root.outputs: - root.add_output(tensor_name) + root.nodes.update(range(len(self.graph.nodes))) + self.compute_region_boundaries(root) return root def _is_tensor_divergent(self, tensor_name: str) -> bool: @@ -376,7 +372,7 @@ def _append_node_to_region(self, node_idx: int): logger.debug(f"Started region {self.current_region_id}") self.current_region_id += 1 - self.current_region.add_node(node_idx) + self.current_region.nodes.add(node_idx) logger.debug( f" 
Added node {node_idx} ({node.op}), region size: {len(self.current_region.nodes)}" ) @@ -576,7 +572,7 @@ def _create_leaf_region(self, node_indices: set[int]) -> Region: ) self.next_region_id += 1 for node_idx in node_indices: - region.add_node(node_idx) + region.nodes.add(node_idx) self.compute_region_boundaries(region) return region @@ -600,7 +596,7 @@ def _split_sequence_regions(self, root: Region) -> list[Region]: region = Region( region_id=self.next_region_id, level=root.level + 1, region_type=RegionType.LEAF ) - region.add_node(node_idx) + region.nodes.add(node_idx) self.compute_region_boundaries(region) result_regions.append(region) self.next_region_id += 1 diff --git a/tests/unit/onnx/quantization/autotune/test_region_pattern.py b/tests/unit/onnx/quantization/autotune/test_region_pattern.py index 2a39f4cae..939cefa76 100644 --- a/tests/unit/onnx/quantization/autotune/test_region_pattern.py +++ b/tests/unit/onnx/quantization/autotune/test_region_pattern.py @@ -149,7 +149,7 @@ def _create_test_region( """Create a test region.""" region = Region(region_id, level, region_type) if node_indices: - region.add_nodes(node_indices) + region.nodes.update(node_indices) return region # ========================================================================= diff --git a/tests/unit/onnx/quantization/autotune/test_region_search.py b/tests/unit/onnx/quantization/autotune/test_region_search.py index e63f5c4c0..96f92eba7 100644 --- a/tests/unit/onnx/quantization/autotune/test_region_search.py +++ b/tests/unit/onnx/quantization/autotune/test_region_search.py @@ -172,7 +172,7 @@ def test_creation(self): # Create a root region with all nodes root = Region(region_id=0, level=0, region_type=RegionType.LEAF) for idx in range(len(graph.nodes)): - root.add_node(idx) + root.nodes.add(idx) builder = TopDownRegionBuilder(graph, root) @@ -397,8 +397,7 @@ def test_print_tree_top_down_builder(self): # Create a root region with all nodes root = Region(region_id=0, level=0, region_type=RegionType.LEAF) - for idx in range(len(graph.nodes)): - root.add_node(idx) + root.nodes.update(range(len(graph.nodes))) builder = TopDownRegionBuilder(graph, root) # Compute region I/O boundaries before building
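
The scheme-identity semantics introduced in common.py (hash ignores measured latency; distance counts add/remove operations via symmetric difference) can be exercised directly with the frozen dataclasses from this patch. A minimal sketch; the asserted values follow from the `distance` and `hash` definitions shown above.

```python
from modelopt.onnx.quantization.autotune import InsertionScheme, NodeInputInsertionPoint

a = InsertionScheme(node_inputs=[NodeInputInsertionPoint(0, 0), NodeInputInsertionPoint(1, 0)])
b = InsertionScheme(node_inputs=[NodeInputInsertionPoint(0, 0), NodeInputInsertionPoint(2, 0)])

# Edit distance = symmetric difference of point sets: remove (1, 0), add (2, 0).
assert a.distance(b) == 2

# Per the hash docstring, identity depends only on insertion points,
# not on measured latency.
a.latency_ms = 1.5
assert a.hash == InsertionScheme(node_inputs=list(a.node_inputs)).hash
```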
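Resolution of a pattern-relative point to concrete tensor locations can be sketched on a synthetic one-node graph (the graph and region below are illustrative only, not taken from any real model):

```python
import numpy as np
import onnx_graphsurgeon as gs

from modelopt.onnx.quantization.autotune import NodeInputInsertionPoint, Region, RegionType

# Synthetic one-node graph: x -> Relu -> y.
x = gs.Variable("x", dtype=np.float32, shape=(1, 16))
y = gs.Variable("y", dtype=np.float32, shape=(1, 16))
graph = gs.Graph(nodes=[gs.Node(op="Relu", inputs=[x], outputs=[y])], inputs=[x], outputs=[y])

# LEAF region covering graph node 0; pattern-relative index 0 maps to it.
region = Region(region_id=1, level=1, region_type=RegionType.LEAF)
region.nodes.add(0)

resolved = NodeInputInsertionPoint(node_index=0, input_index=0).resolve(region, graph)
assert {(ip.tensor_name, ip.node_index, ip.input_index) for ip in resolved} == {("x", 0, 0)}
```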
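`merge_resolved_insertion_points` collapses per-node points into a single tensor-level point only when every consumer of the tensor is covered. A sketch on a synthetic two-consumer graph, assuming only functions defined in insertion_points.py:

```python
import numpy as np
import onnx_graphsurgeon as gs

from modelopt.onnx.quantization.autotune.insertion_points import (
    ResolvedInsertionPoint,
    merge_resolved_insertion_points,
)

# Synthetic graph where tensor "x" feeds two consumers.
x = gs.Variable("x", dtype=np.float32, shape=(1, 16))
y0 = gs.Variable("y0", dtype=np.float32, shape=(1, 16))
y1 = gs.Variable("y1", dtype=np.float32, shape=(1, 16))
graph = gs.Graph(
    nodes=[
        gs.Node(op="Relu", inputs=[x], outputs=[y0]),
        gs.Node(op="Sigmoid", inputs=[x], outputs=[y1]),
    ],
    inputs=[x],
    outputs=[y0, y1],
)

# Both consumers of "x" carry node-level points, so they merge into one
# tensor-level insertion (node_index=None, input_index=None).
points = {
    ResolvedInsertionPoint(tensor_name="x", node_index=0, input_index=0),
    ResolvedInsertionPoint(tensor_name="x", node_index=1, input_index=0),
}
merged = merge_resolved_insertion_points(graph, points)
assert merged == {ResolvedInsertionPoint(tensor_name="x")}
```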
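Finally, a minimal end-to-end sketch of the two-phase search as refactored in this patch; `model.onnx` is a placeholder path, and the coverage check mirrors the invariant that `search_regions` logs:

```python
import onnx
import onnx_graphsurgeon as gs

from modelopt.onnx.quantization.autotune.region_search import CombinedRegionSearch

# Placeholder model path.
graph = gs.import_onnx(onnx.load("model.onnx"))
graph.cleanup().toposort()

search = CombinedRegionSearch(graph, maximum_sequence_region_size=10)
regions = search.search_regions()

# Top-level regions should collectively cover the graph without overlap.
covered: set[int] = set()
for region in regions:
    covered.update(region.get_region_nodes_and_descendants())
print(f"{len(regions)} top-level regions cover {len(covered)}/{len(graph.nodes)} nodes")
```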