def get_bool_ops() -> set[str]:
    """Returns set of bool operations.

    A new set object is built on every call, so callers may mutate the result.
    """
    return {
        "Not",
        "And",
        "Or",
        "Xor",
    }


def get_bitwise_ops() -> set[str]:
    """Returns set of bitwise operations.

    A new set object is built on every call, so callers may mutate the result.
    """
    return {
        "BitwiseAnd",
        "BitwiseOr",
        "BitwiseXor",
        "BitShift",
    }


def get_value_check_ops() -> set[str]:
    """Returns set of value checking operations.

    NOTE(review): "Sign" and "Abs" are grouped here together with the
    "IsNaN"/"IsInf" predicates; confirm that grouping is intentional.
    """
    return {
        "IsNaN",
        "IsInf",
        "Sign",
        "Abs",
    }


def get_comparison_ops() -> set[str]:
    """Returns set of comparison operations.

    A new set object is built on every call, so callers may mutate the result.
    """
    return {
        "Equal",
        "Greater",
        "GreaterOrEqual",
        "Less",
        "LessOrEqual",
    }


def get_conditional_ops() -> set[str]:
    """Returns set of conditional operations.

    A new set object is built on every call, so callers may mutate the result.
    """
    return {
        "Where",
    }
def get_aggregation_ops() -> set[str]:
    """Returns set of aggregation operations.

    A new set object is built on every call, so callers may mutate the result.
    """
    return {
        "All",
        "Any",
    }


def get_set_ops() -> set[str]:
    """Returns set of set/search operations.

    A new set object is built on every call, so callers may mutate the result.
    """
    return {
        "Unique",
        "NonZero",
    }
# Region-related Exceptions
class RegionError(Exception):
    """Root of the exception hierarchy for region-related failures."""


# Autotuner-related Exceptions
class AutotunerError(Exception):
    """Root of the exception hierarchy for autotuner-related failures."""


class AutotunerNotInitializedError(AutotunerError):
    """Signals that the autotuner was used before it was initialized."""


class InvalidSchemeError(AutotunerError):
    """Signals a reference to a scheme that is not valid."""


class RegionType(Enum):
    """Kinds of region in the hierarchical graph structure.

    Each member's value is its own name as a string.

    - LEAF: atomic region holding direct nodes and no child regions
    - COMPOSITE: region holding child regions (and optionally direct nodes)
    - ROOT: top-level region spanning the entire computation graph
    """

    LEAF = "LEAF"
    COMPOSITE = "COMPOSITE"
    ROOT = "ROOT"
class Region:
    """Hierarchical subgraph region in an ONNX computation graph.

    A Region represents a cohesive subgraph with well-defined boundaries, supporting:

    **Hierarchical Structure:**
    - Parent/child relationships forming a multi-level hierarchy
    - LEAF regions contain only direct nodes
    - COMPOSITE regions contain child regions (and optionally direct nodes)
    - ROOT regions encompass the entire graph

    **Node Management:**
    - Direct nodes: Operations directly in this region (not in children)
    - Recursive nodes: All operations including those in descendant regions

    **Boundary Tracking:**
    - Input tensors: Data entering the region from outside
    - Output tensors: Data leaving the region to outside consumers

    **Pattern Matching:**
    - Regions with identical structure share the same pattern signature
    - Pattern-based optimization applies schemes to all matching regions

    Regions are the fundamental unit for Q/DQ insertion and optimization.

    Region identity is compared by ``id`` throughout this class (not by object
    identity), so ids are expected to be unique within one hierarchy.
    """

    def __init__(self, region_id: int, level: int, region_type: RegionType):
        """Initialize a new region.

        Args:
            region_id: Unique identifier within the region hierarchy
            level: Hierarchical level (0 = leaf, higher = more composite)
            region_type: Type classification (LEAF, COMPOSITE, or ROOT)
        """
        self.id = region_id
        self.level = level
        self.type = region_type
        self.parent: Region | None = None
        self.children: list[Region] = []
        # Node indices owned directly by this region (children excluded)
        self.nodes: set[int] = set()
        # Boundary tensor names, kept in insertion order without duplicates
        self.inputs: list[str] = []
        self.outputs: list[str] = []

    # =========================================================================
    # Basic Accessors
    # =========================================================================

    def get_id(self) -> int:
        """Get region ID."""
        return self.id

    def set_id(self, region_id: int) -> None:
        """Set region ID (for RegionBuilder use)."""
        self.id = region_id

    def get_level(self) -> int:
        """Get region level in hierarchy."""
        return self.level

    def set_level(self, level: int) -> None:
        """Set region level in hierarchy (for RegionBuilder use)."""
        self.level = level

    def get_type(self) -> RegionType:
        """Get region type."""
        return self.type

    def set_type(self, region_type: RegionType) -> None:
        """Set region type (for RegionBuilder use)."""
        self.type = region_type

    # =========================================================================
    # Hierarchy Management
    # =========================================================================

    def get_parent(self) -> Optional["Region"]:
        """Get parent region."""
        return self.parent

    def set_parent(self, parent: Optional["Region"]) -> None:
        """Set parent region.

        Only the back-reference is updated; the parent's children list is not
        touched. Use add_child()/remove_child() to keep both sides in sync.
        """
        self.parent = parent

    def get_children(self) -> list["Region"]:
        """Get all child regions (the live internal list, not a copy)."""
        return self.children

    def remove_child(self, child: "Region") -> bool:
        """Remove a child region from this region's children list.

        Children are matched by region id. The children list is rebuilt (a new
        list object is assigned) rather than mutated in place.

        Args:
            child: The child region to remove

        Returns:
            True if child was found and removed, False otherwise
        """
        child_id = child.get_id()
        initial_count = len(self.children)
        self.children = [c for c in self.children if c.get_id() != child_id]
        removed = len(self.children) < initial_count

        # Only clear the back-reference if it still points at this region
        if removed and child.parent and child.parent.get_id() == self.id:
            child.set_parent(None)

        return removed

    def add_child(self, child: "Region") -> None:
        """Add a child sub-region.

        The request is logged and ignored when ``child`` is this region itself,
        when ``child`` is an ancestor of this region (would create a cycle), or
        when ``child`` is already a child. A child that currently belongs to a
        different parent is detached from that parent first.

        Args:
            child: Region to attach below this region
        """
        # Prevent adding self as child
        if child.get_id() == self.id:
            logger.warning(f"Cannot add region {self.id} as its own child")
            return

        # Prevent creating cycles: check if self is already a descendant of child
        if self._is_descendant_of(child):
            logger.warning(
                f"Cycle detected: region {self.id} is already a descendant of region {child.get_id()}"
            )
            return

        # Check if child already has a different parent
        if child.parent is not None and child.parent.get_id() != self.id:
            old_parent_id = child.parent.get_id()
            logger.debug(
                f"Re-parenting region {child.get_id()}: moving from parent {old_parent_id} to {self.id}"
            )
            # Remove from old parent to maintain tree structure
            child.parent.remove_child(child)

        # Check if child is already in children list
        if any(c.get_id() == child.get_id() for c in self.children):
            logger.debug(f"Region {child.get_id()} already child of {self.id}")
            return

        self.children.append(child)
        child.set_parent(self)

    def _is_descendant_of(self, potential_ancestor: "Region") -> bool:
        """Check if this region is a descendant of potential_ancestor.

        Walks the parent chain upwards comparing region ids. Returns False if
        the parent chain itself loops before the ancestor is found.
        """
        visited = set()
        current = self.parent
        while current:
            if current.get_id() in visited:
                # Already visited, there's a cycle in parents
                return False
            visited.add(current.get_id())
            if current.get_id() == potential_ancestor.get_id():
                return True
            current = current.parent
        return False

    # =========================================================================
    # Node Management
    # =========================================================================

    def add_node(self, node_index: int) -> None:
        """Add a node index to this region."""
        self.nodes.add(node_index)

    def add_nodes(self, node_indices: list[int]) -> None:
        """Add multiple node indices to this region."""
        self.nodes.update(node_indices)

    def get_nodes(self) -> set[int]:
        """Get direct node indices in this region only.

        Returns only nodes directly owned by this region, excluding nodes
        in child regions. Use get_all_nodes_recursive() for complete coverage.

        Returns:
            Set of node indices (absolute positions in the graph); this is the
            live internal set, not a copy
        """
        return self.nodes

    def get_all_nodes_recursive(self, _visited: set[int] | None = None) -> set[int]:
        """Get all node indices recursively, including descendants.

        Traverses the entire subtree rooted at this region, collecting nodes
        from this region and all child regions recursively. A region that is
        reached a second time contributes nothing further.

        Args:
            _visited: Internal parameter for cycle detection (do not use)

        Returns:
            Set of all node indices in this region and its descendants
        """
        if _visited is None:
            _visited = set()

        # Detect cycles
        if self.id in _visited:
            logger.warning(f"Cycle detected in region {self.id} during node traversal")
            return set()

        _visited.add(self.id)
        all_nodes = set(self.nodes)
        for child in self.children:
            all_nodes.update(child.get_all_nodes_recursive(_visited))
        return all_nodes

    def contains_node(self, node_index: int) -> bool:
        """Check if region contains a specific node (direct only)."""
        return node_index in self.nodes

    def contains_node_recursive(self, node_index: int, _visited: set[int] | None = None) -> bool:
        """Check if region contains a node recursively.

        Args:
            node_index: Node index to look for
            _visited: Internal parameter for cycle detection (do not use)

        Returns:
            True if the node is owned by this region or any descendant
        """
        if _visited is None:
            _visited = set()

        # Detect cycles
        if self.id in _visited:
            return False

        _visited.add(self.id)

        if self.contains_node(node_index):
            return True
        return any(child.contains_node_recursive(node_index, _visited) for child in self.children)

    # =========================================================================
    # Input/Output Management
    # =========================================================================

    def add_input(self, tensor_name: str) -> None:
        """Add an input tensor name (duplicates are ignored)."""
        if tensor_name not in self.inputs:
            self.inputs.append(tensor_name)

    def add_output(self, tensor_name: str) -> None:
        """Add an output tensor name (duplicates are ignored)."""
        if tensor_name not in self.outputs:
            self.outputs.append(tensor_name)

    def get_inputs(self) -> list[str]:
        """Get region input tensors."""
        return self.inputs

    def get_outputs(self) -> list[str]:
        """Get region output tensors."""
        return self.outputs

    # =========================================================================
    # Size and Query Methods
    # =========================================================================

    def get_size(self) -> int:
        """Get the number of direct nodes in this region.

        Returns:
            Count of nodes directly in this region (excludes child regions)
        """
        return len(self.nodes)

    def get_total_size(self, _visited: set[int] | None = None) -> int:
        """Get total node count recursively including all descendants.

        Computes the sum of the direct node counts of this region and all
        regions below it. Every region is counted exactly once: a region that
        is reached again (corrupted hierarchy with a cycle or a shared child)
        contributes 0 on the repeat visit, matching get_all_nodes_recursive().

        Args:
            _visited: Internal parameter for cycle detection (do not use)

        Returns:
            Total number of nodes in this region and all descendants
        """
        if _visited is None:
            _visited = set()

        # Detect cycles
        if self.id in _visited:
            logger.warning(f"Cycle detected in region {self.id} during size calculation")
            # This region's nodes were already added on the first visit;
            # returning len(self.nodes) here would count them twice.
            return 0

        _visited.add(self.id)
        total = len(self.nodes)
        for child in self.children:
            total += child.get_total_size(_visited)
        return total

    # =========================================================================
    # Region Operations
    # =========================================================================

    def merge(self, other: "Region") -> None:
        """Merge another region into this one.

        Combines the nodes and children from the other region into this region.
        The other region's children become children of this region, updating
        their parent references accordingly. Inputs and outputs are not merged.

        Args:
            other: Region to merge into this one (a falsy value is a no-op)
        """
        if not other:
            return
        # Merge direct nodes
        self.nodes.update(other.nodes)
        # Merge children (updates their parent references). Iterate over a
        # snapshot because add_child() detaches each child from `other`.
        for child in list(other.children):
            self.add_child(child)

    # =========================================================================
    # String Representation
    # =========================================================================

    def to_string(self) -> str:
        """Return a one-line summary of the region for debugging.

        The summary contains the id, level, type value and the number of direct
        nodes, children, inputs and outputs.
        """
        type_str = self.type.value
        return (
            f"Region[id={self.id}, level={self.level}, type={type_str}, "
            f"nodes={len(self.nodes)}, children={len(self.children)}, "
            f"inputs={len(self.inputs)}, outputs={len(self.outputs)}]"
        )

    def __str__(self) -> str:
        return self.to_string()

    def __repr__(self) -> str:
        return self.to_string()

    def compute_structural_signature(self, graph: gs.Graph) -> str:
        """Compute deterministic structural signature for pattern matching.

        Intended to create a signature that uniquely identifies the region's
        topology, node operations, and hierarchical structure, so that regions
        with identical signatures can share Q/DQ insertion schemes.

        Not implemented yet: this method currently always raises.

        Args:
            graph: The ONNX graph containing the region's nodes

        Returns:
            Signature string (e.g., "Conv->BatchNorm->Relu" or "COMPOSITE(...)")

        Raises:
            NotImplementedError: Always, until the signature is implemented
        """
        raise NotImplementedError("Not implemented")
@dataclass
class InsertionScheme:
    """Complete Q/DQ insertion specification for a region pattern.

    An InsertionScheme defines a complete Q/DQ configuration for a pattern,
    combining both node-level and region-level insertion points. The scheme
    is applied to all regions matching the pattern.

    **Scheme Identity:**
    - Uniquely identified by the combination of insertion points (computed hash)
    - latency_ms is a measured performance metric, not part of identity
    - Two schemes with same insertion points but different latencies are considered identical

    **Application:**
    - Node insertion points: Q/DQ at node inputs within the pattern
    - Region insertion points: Q/DQ at child region boundaries (COMPOSITE only)
    - All are resolved to actual configurations for each matching region

    **Performance Tracking:**
    - latency_ms: Measured performance (inf = not yet measured)
    - error: Whether this scheme encountered an error during measurement
    - Used to select the best scheme for each pattern

    **Attributes:**
        node_inputs: Q/DQ insertions at node inputs (list of NodeInputInsertionPoint)
        child_region_inputs: Q/DQ insertions at child boundaries (list of ChildRegionInputInsertionPoint)
        region_outputs: Q/DQ insertions at region outputs (list of RegionOutputInsertionPoint)
        latency_ms: Measured latency in milliseconds (inf if not measured)
        error: True if scheme measurement failed, False otherwise
        profile_timestamp: ISO format timestamp when this scheme was profiled (None if not yet profiled)
    """

    node_inputs: list[NodeInputInsertionPoint] = field(default_factory=list)
    child_region_inputs: list[ChildRegionInputInsertionPoint] = field(default_factory=list)
    region_outputs: list[RegionOutputInsertionPoint] = field(default_factory=list)
    latency_ms: float = float("inf")
    error: bool = False
    profile_timestamp: str | None = None

    @staticmethod
    def _index_sort_key(indices: tuple) -> tuple:
        """Build a total-order sort key for a tuple of indices that may contain None.

        Integers keep their natural order and None sorts after every integer,
        so tuples mixing None and int never hit ``None < int`` comparisons.
        For tuples made only of integers the resulting order is identical to
        plain tuple ordering.
        """
        return tuple((idx is None, 0 if idx is None else idx) for idx in indices)

    @property
    def hash(self) -> str:
        """Compute deterministic hash for scheme identity.

        The hash uniquely identifies this scheme configuration based on its
        insertion points. Two schemes with identical insertion points produce
        the same hash, regardless of their measured latencies.

        **Hash Input:**
        - Sorted node_inputs (for deterministic ordering)
        - Sorted child_region_inputs (for deterministic ordering)
        - Sorted region_outputs (for deterministic ordering)
        - latency_ms is EXCLUDED (performance metric, not identity)

        **Use Cases:**
        - Detect duplicate schemes before measurement
        - Group schemes by configuration
        - Efficient scheme comparison

        Returns:
            32-character hexadecimal string (SHA-256 truncated to 128 bits)
        """
        # Sort points for deterministic hashing.
        # NOTE(review): region outputs presumably carry None in region_index or
        # node_index depending on whether they come from a child region or a
        # node - confirm against RegionOutputInsertionPoint. Plain tuple sorting
        # raises TypeError on such a mix, hence the None-safe key.
        sorted_nodes = sorted(
            ((pt.node_index, pt.input_index) for pt in self.node_inputs),
            key=self._index_sort_key,
        )
        sorted_regions = sorted(
            ((pt.region_index, pt.input_index) for pt in self.child_region_inputs),
            key=self._index_sort_key,
        )
        sorted_region_outputs = sorted(
            ((pt.region_index, pt.node_index, pt.output_index) for pt in self.region_outputs),
            key=self._index_sort_key,
        )

        # Create hash input string
        hash_input = f"{sorted_nodes}|{sorted_regions}|{sorted_region_outputs}"

        # Compute SHA-256 hash and keep the first 32 hex digits (128 bits)
        return hashlib.sha256(hash_input.encode("utf-8")).hexdigest()[:32]

    @property
    def is_empty(self) -> bool:
        """Check if this is a baseline scheme with no Q/DQ insertions.

        Returns:
            True if scheme has no node/region insertion points
        """
        return (
            len(self.node_inputs) == 0
            and len(self.child_region_inputs) == 0
            and len(self.region_outputs) == 0
        )

    @property
    def has_error(self) -> bool:
        """Check if this scheme encountered an error during measurement.

        Returns:
            True if scheme has error=True, False otherwise
        """
        return self.error

    @property
    def is_profiled(self) -> bool:
        """Check if this scheme has been profiled (measured).

        A scheme is considered profiled if it has been measured (has non-infinite latency)
        or has encountered an error during measurement.

        Returns:
            True if scheme has been measured (latency_ms != inf) or has error,
            False if scheme is waiting to be profiled (error=False and latency_ms=inf)
        """
        return self.error or self.latency_ms != float("inf")

    @property
    def num_node_insertions(self) -> int:
        """Get count of node-level Q/DQ insertion points.

        Returns:
            Number of NodeInputInsertionPoint entries
        """
        return len(self.node_inputs)

    @property
    def num_region_insertions(self) -> int:
        """Get count of region-level Q/DQ insertion points.

        These specify Q/DQ insertions at child region boundaries within
        COMPOSITE regions.

        Returns:
            Number of ChildRegionInputInsertionPoint entries
        """
        return len(self.child_region_inputs)

    @property
    def num_region_output_insertions(self) -> int:
        """Get count of region output insertion points.

        These specify Q/DQ insertions at outputs from child regions or nodes.

        Returns:
            Number of RegionOutputInsertionPoint entries
        """
        return len(self.region_outputs)

    def to_dict(self) -> dict[str, Any]:
        """Convert to dictionary for serialization.

        Note that node inputs are stored under the key 'nodes_insertion_points'.
        The computed hash is included for reference; from_dict() ignores it.
        """
        return {
            "latency_ms": self.latency_ms,
            "error": self.error,
            "profile_timestamp": self.profile_timestamp,
            "nodes_insertion_points": [pt.to_dict() for pt in self.node_inputs],
            "child_region_inputs": [pt.to_dict() for pt in self.child_region_inputs],
            "region_outputs": [pt.to_dict() for pt in self.region_outputs],
            "hash": self.hash,
        }

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "InsertionScheme":
        """Create InsertionScheme from serialized dictionary.

        Reconstructs the insertion scheme from saved data, including node and
        region insertion points. Every key is optional; missing keys fall back
        to the dataclass defaults. The hash is automatically recomputed from all
        components to ensure consistency.

        Args:
            data: Dictionary containing 'latency_ms', 'nodes_insertion_points',
                  'child_region_inputs', and 'region_outputs' keys

        Returns:
            Reconstructed InsertionScheme instance
        """
        scheme = cls()
        scheme.latency_ms = data.get("latency_ms", float("inf"))
        scheme.error = data.get("error", False)
        scheme.profile_timestamp = data.get("profile_timestamp")

        scheme.node_inputs = [
            NodeInputInsertionPoint.from_dict(pt) for pt in data.get("nodes_insertion_points", [])
        ]
        scheme.child_region_inputs = [
            ChildRegionInputInsertionPoint.from_dict(pt)
            for pt in data.get("child_region_inputs", [])
        ]
        scheme.region_outputs = [
            RegionOutputInsertionPoint.from_dict(pt) for pt in data.get("region_outputs", [])
        ]

        # Note: hash is computed from points, so we don't load it from dict
        # This ensures consistency even if stored hash differs

        return scheme

    def distance(self, other: "InsertionScheme") -> int:
        """Compute edit distance between this scheme and another scheme.

        The edit distance is the minimum number of add/remove operations needed
        to transform this scheme into the other scheme. This is computed as the
        symmetric difference between the insertion point sets.

        **Distance Calculation:**
        - Counts insertion points in self but not in other (need to be removed)
        - Counts insertion points in other but not in self (need to be added)
        - Considers all three types of insertion points:
          * node_inputs
          * child_region_inputs
          * region_outputs
        - Duplicate points inside one list count once (lists are compared as sets)

        Args:
            other: InsertionScheme to compare against

        Returns:
            Total edit distance (number of add + remove operations)

        Example:
            >>> scheme1 = InsertionScheme(
            ...     node_inputs=[
            ...         NodeInputInsertionPoint(0, 0),
            ...         NodeInputInsertionPoint(1, 0),
            ...     ]
            ... )
            >>> scheme2 = InsertionScheme(
            ...     node_inputs=[
            ...         NodeInputInsertionPoint(0, 0),
            ...         NodeInputInsertionPoint(2, 0),
            ...     ]
            ... )
            >>> scheme1.distance(scheme2)  # 2 (remove (1,0), add (2,0))
            2
        """
        # Convert insertion points to sets for efficient set operations
        self_nodes = set(self.node_inputs)
        other_nodes = set(other.node_inputs)

        self_regions = set(self.child_region_inputs)
        other_regions = set(other.child_region_inputs)

        self_region_outputs = set(self.region_outputs)
        other_region_outputs = set(other.region_outputs)

        # Compute symmetric difference (elements in either set but not both)
        # This gives us the total number of add + remove operations
        node_distance = len(self_nodes.symmetric_difference(other_nodes))
        region_distance = len(self_regions.symmetric_difference(other_regions))
        region_output_distance = len(self_region_outputs.symmetric_difference(other_region_outputs))

        return node_distance + region_distance + region_output_distance

    def __str__(self) -> str:
        """String representation for debugging."""
        error_str = ", error=True" if self.error else ""
        return (
            f"InsertionScheme(node_insertions={self.num_node_insertions}, "
            f"region_insertions={self.num_region_insertions}, "
            f"region_output_insertions={self.num_region_output_insertions}, "
            f"latency={self.latency_ms:.3f}ms{error_str})"
        )
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Q/DQ Insertion Point Management for ONNX Quantization. + +This module provides data structures and utilities for managing Quantization/Dequantization (Q/DQ) +insertion points in ONNX computational graphs during autotune optimization. It enables pattern-based +Q/DQ insertion that can be reused across multiple matching regions in a model. + +Core Concepts: +-------------- +1. **Pattern-Relative Insertion Points**: Insertion points are defined relative to region patterns + rather than absolute node IDs, enabling scheme reuse across all matching regions. + +2. **Resolution Process**: Pattern-relative indices are resolved to actual tensor names for each + specific region instance, then Q/DQ pairs are inserted at the resolved locations. + +3. 
# Operator types listed by the module docstring as boolean/comparison operations
# (not quantizable). The literal also carries selection, reduction and search
# operators, grouped below by category.
BOOL_OPERATIONS = {
    # logical
    "And", "Not", "Or", "Xor",
    # bitwise
    "BitShift", "BitwiseAnd", "BitwiseOr", "BitwiseXor",
    # value checks
    "Abs", "IsInf", "IsNaN", "Sign",
    # comparisons
    "Equal", "Greater", "GreaterOrEqual", "Less", "LessOrEqual",
    # conditional selection
    "Where",
    # element-wise / statistical selection
    "Max", "Mean", "Median", "Min",
    # arg-reductions and reductions
    "ArgMax", "ArgMin", "ReduceMax", "ReduceMean", "ReduceMin", "ReduceSum",
    # aggregation
    "All", "Any",
    # set / search
    "NonZero", "TopK", "Unique",
}

# Operator types listed by the module docstring as shape manipulation
# operations (not quantizable).
SHAPE_OPERATIONS = {
    # value-preserving or element-wise helpers
    "Cast", "Ceil", "Clip", "Identity",
    # layout / rank changes
    "ExpandDims", "Flatten", "Squeeze", "Transpose", "Unsqueeze", "View",
    # joining, splitting and resizing
    "Concat", "Pad", "Slice", "Split", "Tile",
    # indexing, gather and scatter
    "Compress", "Gather", "GatherElements", "GatherND", "Scatter", "ScatterND",
    # shape and sequence generation
    "Range", "Shape",
}

# Operator types listed by the module docstring as key operations that benefit
# from quantization.
MAJOR_QUANTIZABLE_OPERATIONS = {
    # convolution and matrix multiplication
    "Conv", "ConvTranspose", "Gemm", "MatMul",
    # pooling and resampling
    "AveragePool", "GlobalAveragePool", "GlobalMaxPool", "MaxPool", "Resize",
    # element-wise arithmetic and activation
    "Add", "Mul", "Relu", "Sum",
}
+ """ + + tensor_name: str + # Absolute graph node index (or None for tensor-level insertion) + node_index: int | None = None + # Input tensor index of that node (or None) + input_index: int | None = None + + def to_dict(self) -> dict[str, Any]: + """Convert to dictionary for serialization.""" + return { + "tensor_name": self.tensor_name, + "node_index": self.node_index, + "input_index": self.input_index, + } + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "ResolvedInsertionPoint": + """Create from dictionary.""" + return cls( + tensor_name=data["tensor_name"], + node_index=data["node_index"], + input_index=data.get("input_index"), + ) + + def __str__(self) -> str: + """String representation for debugging.""" + return ( + f"ResolvedInsertionPoint(tensor_name={self.tensor_name}, " + f"node={self.node_index}, input={self.input_index})" + ) + + +@dataclass(frozen=True) +class NodeInputInsertionPoint: + """Pattern-relative Q/DQ insertion point at a node's input. + + Specifies where to insert a Q/DQ pair within a region pattern using + pattern-relative indices rather than absolute node IDs. This enables + insertion scheme reuse across all regions matching the same pattern. + + **Resolution Process:** + 1. Pattern-relative indices (node_index, input_index) are defined once + 2. For each matching region, indices are resolved to actual tensor names + 3. Q/DQ pairs are inserted at the resolved tensor locations + + **Example:** + - NodeInputInsertionPoint(node_index=0, input_index=1) + - Resolves to: the second input (index 1) of the first node (index 0) in the pattern + - Actual tensor name depends on the specific region instance + + **Attributes:** + - node_index: Index of the node within the pattern's sorted node list (0-based) + - input_index: Index of the input tensor for that node (0-based) + + This class is immutable (frozen) to allow safe use in sets and as dict keys. 
+ """ + + # Pattern-relative node index + node_index: int + # Input tensor index of that node + input_index: int + + def to_dict(self) -> dict[str, Any]: + """Convert to dictionary for serialization.""" + return {"node_index": self.node_index, "input_index": self.input_index} + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "NodeInputInsertionPoint": + """Create from dictionary.""" + return cls(node_index=data["node_index"], input_index=data["input_index"]) + + def __str__(self) -> str: + """String representation for debugging.""" + return f"NodeInputInsertionPoint(node={self.node_index}, input={self.input_index})" + + def resolve(self, region: "Region", graph: gs.Graph) -> set[ResolvedInsertionPoint]: + """Resolve a node input insertion point to actual tensor names for a matching region. + + Converts pattern-relative node/input indices to absolute node indices and actual + tensor names in the graph. Special handling for Conv/ConvTranspose operations + automatically includes weight quantization when input is quantized. 
+ + Args: + region: The region instance matching this pattern + graph: The ONNX graph containing the nodes + + Returns: + Set of ResolvedInsertionPoint objects with actual tensor names + """ + nodes_list = list(graph.nodes) + node_indices = sorted(region.get_nodes()) + resolved_ips = set() + + # Map from pattern-relative node index to absolute graph node index + assert self.node_index < len(node_indices), "Node index out of range" + actual_node_idx = node_indices[self.node_index] + assert actual_node_idx < len(nodes_list), "Node index out of range" + node = nodes_list[actual_node_idx] + assert self.input_index < len(node.inputs), "Input index out of range" + + # Resolve the input tensor name using input_index + inp = node.inputs[self.input_index] + if hasattr(inp, "name") and inp.name: + ip = ResolvedInsertionPoint( + tensor_name=inp.name, node_index=actual_node_idx, input_index=self.input_index + ) + resolved_ips.add(ip) + + if node.op in ["Conv", "ConvTranspose"]: + assert self.input_index == 0, ( + "Conv and ConvTranspose inputs and weights should be quantized at same time" + ) + assert len(node.inputs) >= 2, "Conv and ConvTranspose should have at least 2 inputs" + inp = node.inputs[1] + if hasattr(inp, "name") and inp.name: + ip = ResolvedInsertionPoint( + tensor_name=inp.name, node_index=actual_node_idx, input_index=1 + ) + resolved_ips.add(ip) + + return resolved_ips + + @staticmethod + def collect_from_region(region: "Region", graph: gs.Graph) -> list["NodeInputInsertionPoint"]: + """Collect all valid node input insertion points from a region. + + Analyzes each node in the region and identifies all valid input tensors + where Q/DQ pairs could be inserted. Filters out invalid insertion points + using skip_invalid_insertion_points(). 
+ + Args: + region: The region to collect insertion points from + graph: The ONNX graph containing the nodes + + Returns: + List of NodeInputInsertionPoint objects representing valid insertion locations + """ + nodes_list = list(graph.nodes) + node_indices = sorted(region.get_nodes()) + + node_input_insertion_points = [] + for local_idx, node_idx in enumerate(node_indices): + assert node_idx < len(nodes_list), "Node index out of range" + node = nodes_list[node_idx] + # Analyze each input of the node + for input_idx, inp in enumerate(node.inputs): + # Skip if tensor doesn't have a valid name + if not (hasattr(inp, "name") and inp.name): + continue + # Skip if insertion point is invalid (wrong dtype, small size, special input, etc.) + if skip_invalid_insertion_points(graph, inp.name, node): + continue + # Create insertion point for valid tensor + ip = NodeInputInsertionPoint( + # Pattern-relative node index + node_index=local_idx, + input_index=input_idx, + ) + node_input_insertion_points.append(ip) + + return node_input_insertion_points + + +@dataclass(frozen=True) +class ChildRegionInputInsertionPoint: + """Pattern-relative Q/DQ insertion point at a child region's input boundary. + + Specifies where to insert Q/DQ pairs at the input boundaries of child regions + within COMPOSITE regions. This allows parent regions to control quantization + at child boundaries, potentially overriding or complementing child region + optimizations. + + **Use Case:** + Parent regions can insert Q/DQ pairs at child region inputs to: + - Add quantization at child boundaries even if the child has no internal Q/DQ + - Override or supplement the child's own boundary Q/DQ decisions + - Apply different quantization schemes based on the parent context + + **Resolution Process:** + 1. Pattern-relative indices (region_index, input_index) are defined once + 2. 
For each matching parent region, indices resolve to actual child boundaries: + - region_index identifies which child region (in parent's sorted child list) + - input_index identifies which input tensor of that child region + 3. Q/DQ pairs are inserted at the resolved child input tensor locations + + **Example:** + - ChildRegionInputInsertionPoint(region_index=0, input_index=1) + - Resolves to: the second input tensor (index 1) of the first child region (index 0) + - Actual tensor name depends on the specific parent/child region instances + + **Note:** Only applies to COMPOSITE regions. LEAF regions have no children, + so child region insertion points have no effect there. + + **Attributes:** + - region_index: Index of the child region within the parent pattern's sorted child list (0-based) + - input_index: Index of the input tensor for that child region (0-based) + + This class is immutable (frozen) to allow safe use in sets and as dict keys. + """ + + # Index of the child region within the parent pattern's sorted child list (0-based) + region_index: int + # Index of the input tensor for that child region (0-based) + input_index: int + + def to_dict(self) -> dict[str, Any]: + """Convert to dictionary for serialization.""" + return {"region_index": self.region_index, "input_index": self.input_index} + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "ChildRegionInputInsertionPoint": + """Create from dictionary. + + Backward compatible: Ignores obsolete fields like 'child_region_id' + from older serialization formats. 
+ + Args: + data: Dictionary with 'region_index' and 'input_index' keys + + Returns: + ChildRegionInputInsertionPoint instance + """ + # Ignore child_region_id if present in old data + return cls(region_index=data["region_index"], input_index=data["input_index"]) + + def __str__(self) -> str: + """String representation for debugging.""" + return ( + f"ChildRegionInputInsertionPoint(region={self.region_index}, input={self.input_index})" + ) + + def resolve(self, region: "Region", graph: gs.Graph) -> set[ResolvedInsertionPoint]: + """Resolve a child region input insertion point to actual tensor names for a matching region. + + Converts pattern-relative child region index and input index to the actual tensor + name at that child region's input boundary, then resolves to all node inputs that + consume that tensor. + + Args: + region: The parent region instance matching this pattern + graph: The ONNX graph containing the nodes + + Returns: + Set of ResolvedInsertionPoint objects with actual tensor names. + Returns empty set for LEAF regions (no children). 
+ """ + from modelopt.onnx.quantization.autotune.common import RegionType + + if graph is None: + raise ValueError("graph parameter is required") + + # LEAF regions have no child boundaries + if region.get_type() == RegionType.LEAF: + return set() + + # Get sorted child regions (must match order in RegionPattern._compute_signature_recursive) + children_regions = region.get_children() + children_regions = sorted( + children_regions, key=lambda r: (-r.get_level(), r.get_total_size()) + ) + # Map from pattern-relative child index to actual child region + resolved_ips = set() + assert self.region_index < len(children_regions), "Child region index out of range" + child_region = children_regions[self.region_index] + assert self.input_index < len(child_region.get_inputs()), "Input index out of range" + # Resolve the input tensor name using input_index + tensor_name = child_region.get_inputs()[self.input_index] + assert tensor_name is not None, "Tensor name is required" + resolved_ips.update(resolve_region_io_insertion_points(child_region, graph, tensor_name)) + + return resolved_ips + + @staticmethod + def collect_from_region( + region: "Region", graph: gs.Graph + ) -> list["ChildRegionInputInsertionPoint"]: + """Collect all valid child region input insertion points from a region. + + For COMPOSITE regions, analyzes each child region and identifies all valid + input tensors where Q/DQ pairs could be inserted at child boundaries. + Returns empty list for LEAF regions (no children). 
+ + Args: + region: The parent region to collect insertion points from + graph: The ONNX graph containing the nodes + + Returns: + List of ChildRegionInputInsertionPoint objects representing valid insertion locations + """ + from modelopt.onnx.quantization.autotune.common import RegionType + + child_region_input_insertion_points = [] + + # Only COMPOSITE regions have child boundaries for Q/DQ insertion + if region.get_type() != RegionType.LEAF: + # Get all child regions, sorted for deterministic ordering + # Must match sorting in _compute_signature_recursive to ensure + # insertion point indices align with pattern structure + children_regions = region.get_children() + children_regions = sorted( + children_regions, key=lambda r: (-r.get_level(), r.get_total_size()) + ) + + for local_idx, child_region in enumerate(children_regions): + # Create insertion point for each input tensor of the child region + for input_idx, inp in enumerate(child_region.get_inputs()): + if skip_invalid_insertion_points(graph, inp, child_region): + continue + point = ChildRegionInputInsertionPoint( + # Child region index within parent pattern + region_index=local_idx, + # Input index within child region + input_index=input_idx, + ) + child_region_input_insertion_points.append(point) + + return child_region_input_insertion_points + + +@dataclass(frozen=True) +class RegionOutputInsertionPoint: + """Pattern-relative Q/DQ insertion point at an output location. + + Specifies where to insert Q/DQ pairs at output boundaries. This can be either: + 1. Output from a child region (in COMPOSITE regions) + 2. Output from a node within the region + + **Use Case:** + Parent regions can: + - Add Q/DQ at child region output boundaries + - Add Q/DQ at node outputs within the region + - Control quantization precision as data flows through the region hierarchy + + **Resolution Process:** + 1. Pattern-relative indices are defined once + 2. 
If output is from a child region: use region_index (node_index is None) + - region_index identifies which child region (in sorted order) + - output_index identifies which output tensor of that child region + 3. If output is from a node: use node_index (region_index is None) + - node_index identifies which node (in sorted order) + - output_index identifies which output tensor of that node + 4. Resolves to the actual tensor name at that output location + + **Examples:** + - RegionOutputInsertionPoint(region_index=0, node_index=None, output_index=0) + → First output of the first child region + - RegionOutputInsertionPoint(region_index=None, node_index=2, output_index=1) + → Second output of the third node + + **Note:** Exactly one of region_index or node_index must be set (the other must be None). + + **Attributes:** + - region_index: Index of child region within parent pattern (0-based), or None + - node_index: Index of node within the region (0-based), or None + - output_index: Index of the output tensor (0-based) + + This class is immutable (frozen) to allow safe use in sets and as dict keys. + """ + + # Index of child region within parent pattern (0-based), or None + region_index: int | None + # Index of node within the region (0-based), or None + node_index: int | None + # Index of the output tensor (0-based) + output_index: int + + def to_dict(self) -> dict[str, Any]: + """Convert to dictionary for serialization.""" + return { + "region_index": self.region_index, + "node_index": self.node_index, + "output_index": self.output_index, + } + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "RegionOutputInsertionPoint": + """Create from dictionary. 
+ + Args: + data: Dictionary with 'region_index', 'node_index', and 'output_index' keys + + Returns: + RegionOutputInsertionPoint instance + """ + return cls( + region_index=data.get("region_index"), + node_index=data.get("node_index"), + output_index=data["output_index"], + ) + + def __str__(self) -> str: + """String representation for debugging.""" + if self.region_index is not None: + return f"RegionOutputInsertionPoint(region={self.region_index}, output={self.output_index})" + else: + return f"RegionOutputInsertionPoint(node={self.node_index}, output={self.output_index})" + + def resolve(self, region: "Region", graph: gs.Graph) -> set[ResolvedInsertionPoint]: + """Resolve a region output insertion point to actual tensor names for a matching region. + + Converts pattern-relative indices to the actual tensor name at an output location: + - If region_index is set: Resolves to a child region's output tensor + - If node_index is set: Resolves to a node's output tensor + + Then identifies all node inputs that consume that output tensor. 
+ + Args: + region: The region instance matching this pattern + graph: The ONNX graph containing the nodes + + Returns: + Set of ResolvedInsertionPoint objects with actual tensor names + """ + if graph is None: + raise ValueError("graph parameter is required") + + # Get sorted nodes for node output resolution + nodes_list = list(graph.nodes) + node_indices = sorted(region.get_nodes()) + children_regions = region.get_children() + children_regions = sorted( + children_regions, key=lambda r: (-r.get_level(), r.get_total_size()) + ) + + # Resolve each region output insertion point from the scheme to actual tensor names + resolved_ips = set() + # Handle child region outputs (region_index is set) + if self.region_index is not None: + assert self.region_index < len(children_regions), "Region index out of range" + child_region = children_regions[self.region_index] + assert self.output_index < len(child_region.get_outputs()), "Output index out of range" + tensor_name = child_region.get_outputs()[self.output_index] + assert tensor_name is not None, "Invalid tensor name" + resolved_ips.update( + resolve_region_io_insertion_points(child_region, graph, tensor_name) + ) + # Handle node outputs (node_index is set) + elif self.node_index is not None: + assert self.node_index < len(node_indices), "Node index out of range" + node_idx = node_indices[self.node_index] + assert node_idx < len(nodes_list), "Node index out of range" + node = nodes_list[node_idx] + assert self.output_index < len(node.outputs), "Output index out of range" + tensor = node.outputs[self.output_index] + assert tensor is not None, "Invalid tensor name" + assert hasattr(tensor, "name") and tensor.name, "Tensor name is required" + resolved_ips.update(resolve_region_io_insertion_points(None, graph, tensor.name)) + return resolved_ips + + @staticmethod + def collect_from_region( + region: "Region", graph: gs.Graph + ) -> list["RegionOutputInsertionPoint"]: + """Collect all valid region output insertion points from a 
region. + + Identifies all valid output tensors (from child regions or nodes) that leave + the region boundary and could have Q/DQ pairs inserted. Only includes outputs + that are actual region outputs (not consumed internally). + + For COMPOSITE regions: + - Collects child region outputs that are also region outputs + - Collects node outputs that are region outputs + + For LEAF regions: + - Only collects node outputs that are region outputs + + Args: + region: The region to collect insertion points from + graph: The ONNX graph containing the nodes + + Returns: + List of RegionOutputInsertionPoint objects representing valid insertion locations + """ + from modelopt.onnx.quantization.autotune.common import RegionType + + nodes_list = list(graph.nodes) + node_indices = sorted(region.get_nodes()) + region_outputs_set = set(region.get_outputs()) + + # Only include outputs that are actual region outputs (leave the region) + region_output_insertion_points = [] + if region.get_type() != RegionType.LEAF: + # For COMPOSITE regions: check if child region output is a region output + children_regions = region.get_children() + children_regions = sorted( + children_regions, key=lambda r: (-r.get_level(), r.get_total_size()) + ) + for local_idx, child_region in enumerate(children_regions): + for output_idx, out in enumerate(child_region.get_outputs()): + if out not in region_outputs_set: + continue + if skip_invalid_insertion_points(graph, out, child_region): + continue + point = RegionOutputInsertionPoint( + region_index=local_idx, + node_index=None, + output_index=output_idx, + ) + region_output_insertion_points.append(point) + # For all regions: check if node output is a region output + for local_idx, node_idx in enumerate(node_indices): + assert node_idx < len(nodes_list), "Node index out of range" + node = nodes_list[node_idx] + for output_idx, out in enumerate(node.outputs): + # Skip if tensor doesn't have a valid name + if not (hasattr(out, "name") and out.name): + 
continue + # Skip if this output is not a region output (i.e., it's consumed internally) + if out.name not in region_outputs_set: + continue + # Skip if insertion point is invalid (wrong dtype, small size, etc.) + if skip_invalid_insertion_points(graph, out.name, node): + continue + # Create insertion point for valid output tensor + point = RegionOutputInsertionPoint( + region_index=None, + node_index=local_idx, + output_index=output_idx, + ) + region_output_insertion_points.append(point) + + return region_output_insertion_points + + +InsertionPointType = ( + NodeInputInsertionPoint | ChildRegionInputInsertionPoint | RegionOutputInsertionPoint +) + + +def skip_invalid_insertion_points( + graph: gs.Graph, tensor_name: str, region_or_node: "Region | gs.Node" +) -> bool: + """Determine if a tensor should be skipped for Q/DQ insertion. + + Filters out tensors that are not suitable for quantization based on various criteria: + - Boolean and shape operations (not quantizable) + - Fused operation patterns (Conv->BatchNorm->ReLU) + - Operation-specific non-quantizable inputs (weights, biases, BN parameters) + - Non-floating-point tensors (indices, masks) + - Small tensors (scalars, small vectors with < 8 elements) + + Args: + graph: The ONNX graph containing the nodes + tensor_name: Name of the tensor to evaluate + region_or_node: Either a Region or a Node to check for usage of this tensor + + Returns: + True if the insertion point should be skipped, False if it's valid for quantization + """ + from modelopt.onnx.quantization.autotune.common import Region + + if isinstance(region_or_node, Region): + node_indices = region_or_node.get_all_nodes_recursive() + nodes: list[gs.Node] = [graph.nodes[node_idx] for node_idx in node_indices] + else: + assert isinstance(region_or_node, gs.Node) + nodes = [region_or_node] + + for node in nodes: + for input_idx, inp in enumerate(node.inputs): + if hasattr(inp, "name") and inp.name == tensor_name: + # Skip weights of Conv and 
ConvTranspose, they should be quantized with inputs at same time + if node.op in ["Conv", "ConvTranspose"] and input_idx >= 1: + return True + if node.op in ["Relu", "LeakyRelu", "Softmax"]: + # Conv -> ReLU/LeakyRelu/Softmax + if len(node.inputs) == 1 and len(node.inputs[0].inputs) == 1: + producer = node.inputs[0].inputs[0] + if producer.op in ["Conv", "ConvTranspose"]: + return True + # Conv -> BatchNormalization -> ReLU/LeakyRelu/Softmax + if len(node.inputs) == 1 and len(node.inputs[0].inputs) == 1: + producer = node.inputs[0].inputs[0] + if producer.op == "BatchNormalization": + assert len(producer.inputs) >= 1, ( + "BN node should have more than one inputs" + ) + if len(producer.inputs[0].inputs) == 1: + producer = producer.inputs[0].inputs[0] + if producer.op in ["Conv", "ConvTranspose"]: + return True + # Conv -> BatchNormalization -> ReLU/LeakyRelu/Softmax + if node.op == "BatchNormalization": + assert len(node.inputs) >= 1, "BN node should have more than one inputs" + if len(node.inputs[0].inputs) == 1: + producer = node.inputs[0].inputs[0] + if producer.op in ["Conv", "ConvTranspose"]: + return True + # Filter 1: out boolean operations + if node.op in BOOL_OPERATIONS: + return True + # Filter 2: out shape operations + if node.op in SHAPE_OPERATIONS: + return True + # Filter 3: Skip operation-specific non-quantizable inputs + if node.op in ["BatchNormalization", "Resize"] and input_idx >= 1: + return True + if node.op in ["Conv", "Gemm"] and input_idx >= 2: + return True + # Filter 4: Skip non-floating-point tensors (int/bool indices, masks, etc.) 
+ if hasattr(inp, "dtype") and inp.dtype not in [ + None, + np.float32, + np.float16, + np.float64, + ]: + return True + # Filter 5: Skip small tensors (scalars, small vectors) + if hasattr(inp, "shape") and inp.shape is not None: + if all(isinstance(s, int) for s in inp.shape): + if np.prod(inp.shape) < 8: + return True + return False + + +def has_quantizable_operations(region: "Region", graph: gs.Graph) -> bool: + """Check if a region contains major quantizable operations. + + Args: + region: The region to check + graph: The ONNX graph containing the nodes + + Returns: + True if the region contains major quantizable operations, False otherwise + """ + from modelopt.onnx.quantization.autotune.common import RegionType + + # only check leaf regions for quantizable operations + if region.get_type() == RegionType.LEAF: + region_ops = {graph.nodes[idx].op for idx in region.get_nodes()} + return bool(region_ops.intersection(MAJOR_QUANTIZABLE_OPERATIONS)) + return True + + +def resolve_region_io_insertion_points( + region: "Region | None", graph: gs.Graph, tensor_name: str +) -> set[ResolvedInsertionPoint]: + """Resolve region input/output boundaries to actual Q/DQ insertion points. + + For a given tensor at a region boundary (input or output), this function + identifies all the actual node inputs where Q/DQ pairs should be inserted. + It considers both nodes within the region (if provided) and all users of + the tensor in the graph. 
+ + **Use Cases:** + - Child region inputs: Find all nodes inside the child that consume the input tensor + - Child region outputs: Find all nodes outside the child that consume the output tensor + - Node outputs: Find all nodes that consume the tensor (region can be None) + + Args: + region: The region to search within (or None to search entire graph) + graph: The ONNX graph containing the nodes + tensor_name: Name of the tensor at the region boundary + + Returns: + Set of ResolvedInsertionPoint objects specifying where to insert Q/DQ pairs + """ + resolved_insertion_points = set() + tensor_users_map: dict[str, list[int]] = {} + if hasattr(graph, "tensor_users_map"): + tensor_users_map = graph.tensor_users_map + if not tensor_users_map: + tensor_users_map = get_tensor_consumer_node_indices(graph) + + if region is not None: + for node_idx in region.get_all_nodes_recursive(): + assert node_idx < len(graph.nodes), "Node index out of range" + node = graph.nodes[node_idx] + for input_idx, inp in enumerate(node.inputs): + if inp.name == tensor_name: + ip = ResolvedInsertionPoint( + tensor_name=tensor_name, node_index=node_idx, input_index=input_idx + ) + resolved_insertion_points.add(ip) + + if tensor_name in tensor_users_map: + for node_idx in tensor_users_map[tensor_name]: + node = graph.nodes[node_idx] + for input_idx, inp in enumerate(node.inputs): + if inp.name == tensor_name: + ip = ResolvedInsertionPoint( + tensor_name=tensor_name, node_index=node_idx, input_index=input_idx + ) + resolved_insertion_points.add(ip) + + return resolved_insertion_points + + +def merge_resolved_insertion_points( + graph: gs.Graph, resolved_insertion_points: set[ResolvedInsertionPoint] +) -> set[ResolvedInsertionPoint]: + """Optimize insertion points by merging node-specific insertions into tensor-level insertions. 
+ + When all consumers (users) of a tensor have Q/DQ insertion points, it's more efficient + to insert Q/DQ once at the tensor level rather than at each individual node input. + This reduces the number of Q/DQ nodes in the graph and simplifies the quantization scheme. + + **Optimization Logic:** + - For each tensor with multiple node-specific insertion points: + - If ALL users of the tensor have insertion points → merge to tensor-level insertion + - If SOME users have insertion points → keep node-specific insertions + + Args: + graph: The ONNX graph containing the nodes + resolved_insertion_points: Set of resolved insertion points to optimize + + Returns: + Optimized set of insertion points with merged tensor-level insertions where possible + """ + tensor_users_map = get_tensor_consumer_node_indices(graph) + node_input_insertion_points = { + ip for ip in resolved_insertion_points if ip.node_index is not None + } + tensor_names = {ip.tensor_name for ip in node_input_insertion_points} + + results = resolved_insertion_points.difference(node_input_insertion_points) + for tensor_name in tensor_names: + all_users = set(tensor_users_map[tensor_name]) + qdq_users = { + user for user in node_input_insertion_points if user.tensor_name == tensor_name + } + qdq_user_ids = set({user.node_index for user in qdq_users}) + if all_users == qdq_user_ids: + results.add( + ResolvedInsertionPoint(tensor_name=tensor_name, node_index=None, input_index=None) + ) + else: + results.update(qdq_users) + + return results diff --git a/modelopt/onnx/quantization/graph_utils.py b/modelopt/onnx/quantization/graph_utils.py index 67596d5df..f05a08bfa 100755 --- a/modelopt/onnx/quantization/graph_utils.py +++ b/modelopt/onnx/quantization/graph_utils.py @@ -302,6 +302,35 @@ def get_tensor_consumer_nodes( return tensor_consumers +def get_tensor_consumer_node_indices(graph: onnx.GraphProto | gs.Graph) -> dict[str, list[int]]: + """Build a mapping from tensor names to the indices of nodes that use them. 
+ + Args: + graph: ONNX GraphSurgeon graph to analyze + + Returns: + Dictionary mapping tensor names to lists of node indices that consume them + """ + tensor_consumer_map: dict[str, list[int]] = defaultdict(list) + + if isinstance(graph, gs.Graph): + for node_idx, node in enumerate(graph.nodes): + for t in node.inputs: + name = getattr(t, "name", None) + if not name: + continue + tensor_consumer_map[name].append(node_idx) + return tensor_consumer_map + + # onnx.GraphProto case: node.input is repeated string + for node_idx, node in enumerate(graph.node): + for input_name in node.input: + if not input_name: + continue + tensor_consumer_map[input_name].append(node_idx) + return tensor_consumer_map + + def filter_quantizable_kgen_heads( cask_fusible_partitions: list[list[Node]], kgen_partitions: list[list[Node]], diff --git a/tests/unit/onnx/quantization/autotune/test_insertion_points.py b/tests/unit/onnx/quantization/autotune/test_insertion_points.py new file mode 100644 index 000000000..d71524442 --- /dev/null +++ b/tests/unit/onnx/quantization/autotune/test_insertion_points.py @@ -0,0 +1,1331 @@ +#!/usr/bin/env python3 +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Comprehensive tests for common data structures in the autotuner. + +Tests: +1. 
InsertionPoint classes (NodeInputInsertionPoint, RegionOutputInsertionPoint, ChildRegionInputInsertionPoint) +2. InsertionScheme serialization/deserialization +3. InsertionScheme hashing and equality +4. InsertionScheme properties and methods +5. PatternSchemes management +6. Utility functions (skip_invalid_insertion_points, has_quantizable_operations, etc.) +7. Resolve and collect_from methods for all InsertionPoint types +""" + +import unittest +from unittest.mock import MagicMock, patch + +import numpy as np +import onnx_graphsurgeon as gs + +from modelopt.onnx.quantization.autotune.common import ( + ChildRegionInputInsertionPoint, + InsertionScheme, + NodeInputInsertionPoint, + Region, + RegionOutputInsertionPoint, + RegionType, +) +from modelopt.onnx.quantization.autotune.insertion_points import ( + ResolvedInsertionPoint, + has_quantizable_operations, + merge_resolved_insertion_points, + resolve_region_io_insertion_points, + skip_invalid_insertion_points, +) +from modelopt.onnx.quantization.graph_utils import get_tensor_consumer_node_indices + + +class TestNodeInputInsertionPoint(unittest.TestCase): + """Test NodeInputInsertionPoint functionality.""" + + def test_creation(self): + """Test creating NodeInputInsertionPoint.""" + point = NodeInputInsertionPoint(node_index=5, input_index=2) + assert point.node_index == 5 + assert point.input_index == 2 + + def test_immutability(self): + """Test that NodeInputInsertionPoint is immutable (frozen).""" + point = NodeInputInsertionPoint(node_index=1, input_index=0) + passed = False + try: + point.node_index = 2 + except AttributeError: + passed = True + assert passed, "NodeInputInsertionPoint should be immutable" + + def test_equality(self): + """Test equality comparison.""" + point1 = NodeInputInsertionPoint(node_index=3, input_index=1) + point2 = NodeInputInsertionPoint(node_index=3, input_index=1) + point3 = NodeInputInsertionPoint(node_index=3, input_index=2) + + assert point1 == point2 + assert point1 != point3 + 
+ def test_hashable(self): + """Test that points can be used in sets and dicts.""" + point1 = NodeInputInsertionPoint(node_index=1, input_index=0) + point2 = NodeInputInsertionPoint(node_index=1, input_index=0) + point3 = NodeInputInsertionPoint(node_index=2, input_index=0) + + point_set = {point1, point2, point3} + assert len(point_set) == 2 # point1 and point2 are the same + + def test_serialization(self): + """Test to_dict and from_dict.""" + point = NodeInputInsertionPoint(node_index=7, input_index=3) + + data = point.to_dict() + assert data["node_index"] == 7 + assert data["input_index"] == 3 + + restored = NodeInputInsertionPoint.from_dict(data) + assert point == restored + + def test_string_representation(self): + """Test __str__ method.""" + point = NodeInputInsertionPoint(node_index=2, input_index=1) + s = str(point) + assert "2" in s + assert "1" in s + + +class TestRegionOutputInsertionPoint(unittest.TestCase): + """Test RegionOutputInsertionPoint functionality.""" + + def test_creation_with_region_index(self): + """Test creating with region_index (child region output).""" + point = RegionOutputInsertionPoint(region_index=2, node_index=None, output_index=1) + assert point.region_index == 2 + assert point.node_index is None + assert point.output_index == 1 + + def test_creation_with_node_index(self): + """Test creating with node_index (node output).""" + point = RegionOutputInsertionPoint(region_index=None, node_index=5, output_index=0) + assert point.region_index is None + assert point.node_index == 5 + assert point.output_index == 0 + + def test_immutability(self): + """Test that RegionOutputInsertionPoint is immutable (frozen).""" + point = RegionOutputInsertionPoint(region_index=1, node_index=None, output_index=0) + passed = False + try: + point.region_index = 2 + except AttributeError: + passed = True + assert passed, "RegionOutputInsertionPoint should be immutable" + + def test_equality(self): + """Test equality comparison.""" + point1 = 
class TestChildRegionInputInsertionPoint(unittest.TestCase):
    """Test ChildRegionInputInsertionPoint functionality.

    Covers construction, immutability, equality, hashing, dict round-tripping
    and the string representation of the insertion point.
    """

    def test_creation(self):
        """Test creating ChildRegionInputInsertionPoint."""
        point = ChildRegionInputInsertionPoint(region_index=3, input_index=1)
        assert point.region_index == 3
        assert point.input_index == 1

    def test_immutability(self):
        """Test that ChildRegionInputInsertionPoint is immutable (frozen)."""
        point = ChildRegionInputInsertionPoint(region_index=1, input_index=0)
        # Attribute assignment must be rejected with AttributeError; the test
        # fails if the assignment succeeds.
        with self.assertRaises(
            AttributeError, msg="ChildRegionInputInsertionPoint should be immutable"
        ):
            point.region_index = 2

    def test_equality(self):
        """Test equality comparison."""
        point1 = ChildRegionInputInsertionPoint(region_index=2, input_index=0)
        point2 = ChildRegionInputInsertionPoint(region_index=2, input_index=0)
        point3 = ChildRegionInputInsertionPoint(region_index=2, input_index=1)

        assert point1 == point2
        assert point1 != point3

    def test_hashable(self):
        """Test that points can be used in sets and dicts."""
        point1 = ChildRegionInputInsertionPoint(region_index=1, input_index=0)
        point2 = ChildRegionInputInsertionPoint(region_index=1, input_index=0)
        point3 = ChildRegionInputInsertionPoint(region_index=2, input_index=0)

        point_set = {point1, point2, point3}
        assert len(point_set) == 2  # point1 and point2 are the same

    def test_serialization(self):
        """Test to_dict and from_dict round-trip."""
        point = ChildRegionInputInsertionPoint(region_index=5, input_index=2)

        data = point.to_dict()
        assert data["region_index"] == 5
        assert data["input_index"] == 2

        restored = ChildRegionInputInsertionPoint.from_dict(data)
        assert point == restored

    def test_string_representation(self):
        """Test that __str__ mentions both indices."""
        point = ChildRegionInputInsertionPoint(region_index=2, input_index=1)
        s = str(point)
        assert "2" in s
        assert "1" in s
class TestInsertionScheme(unittest.TestCase):
    """Checks for InsertionScheme: emptiness, hashing and dict round-trips."""

    @staticmethod
    def _scheme_with_node_inputs(*index_pairs):
        """Return a scheme whose node_inputs are built from (node, input) pairs."""
        built = InsertionScheme()
        built.node_inputs = [
            NodeInputInsertionPoint(node_idx, input_idx) for node_idx, input_idx in index_pairs
        ]
        return built

    @staticmethod
    def _round_trip(original):
        """Serialize a scheme to a dict and rebuild it."""
        return InsertionScheme.from_dict(original.to_dict())

    def test_empty_scheme(self):
        """A freshly constructed scheme has no points and no error."""
        blank = InsertionScheme()

        assert blank.is_empty
        assert len(blank.node_inputs) == 0
        assert len(blank.child_region_inputs) == 0
        assert len(blank.region_outputs) == 0
        assert not blank.error

    def test_scheme_with_node_inputs(self):
        """Node input points make the scheme non-empty."""
        populated = self._scheme_with_node_inputs((0, 0), (1, 0))

        assert not populated.is_empty
        assert len(populated.node_inputs) == 2

    def test_scheme_with_region_outputs(self):
        """Region output points make the scheme non-empty."""
        populated = InsertionScheme()
        node_output = RegionOutputInsertionPoint(None, 0, 0)
        child_output = RegionOutputInsertionPoint(1, None, 0)
        populated.region_outputs = [node_output, child_output]

        assert not populated.is_empty
        assert len(populated.region_outputs) == 2

    def test_scheme_with_composite_regions(self):
        """Child region input points make the scheme non-empty."""
        populated = InsertionScheme()
        populated.child_region_inputs = [
            ChildRegionInputInsertionPoint(child_idx, 0) for child_idx in (0, 1)
        ]

        assert not populated.is_empty
        assert len(populated.child_region_inputs) == 2

    def test_scheme_hash_empty(self):
        """Two empty schemes share the same hash."""
        assert InsertionScheme().hash == InsertionScheme().hash

    def test_scheme_hash_with_points(self):
        """Equal point lists hash equal; a differing point changes the hash."""
        baseline = self._scheme_with_node_inputs((0, 0), (1, 0))
        duplicate = self._scheme_with_node_inputs((0, 0), (1, 0))
        variant = self._scheme_with_node_inputs((0, 0), (2, 0))  # second point differs

        assert baseline.hash == duplicate.hash
        assert baseline.hash != variant.hash

    def test_scheme_hash_order_independent(self):
        """The hash does not depend on the order of the insertion points."""
        forward = self._scheme_with_node_inputs((0, 0), (1, 0))
        backward = self._scheme_with_node_inputs((1, 0), (0, 0))  # reversed order

        assert forward.hash == backward.hash

    def test_serialization_empty(self):
        """An empty scheme survives a dict round-trip with its defaults."""
        rebuilt = self._round_trip(InsertionScheme())

        assert rebuilt.is_empty
        assert rebuilt.latency_ms == float("inf")
        assert not rebuilt.error

    def test_serialization_full(self):
        """All three point kinds plus latency and error survive a round-trip."""
        source = InsertionScheme()
        source.node_inputs = [NodeInputInsertionPoint(0, 0)]
        source.child_region_inputs = [ChildRegionInputInsertionPoint(0, 0)]
        source.region_outputs = [RegionOutputInsertionPoint(None, 0, 0)]
        source.latency_ms = 12.5
        source.error = False

        rebuilt = self._round_trip(source)

        assert len(rebuilt.node_inputs) == 1
        assert len(rebuilt.child_region_inputs) == 1
        assert len(rebuilt.region_outputs) == 1
        assert rebuilt.latency_ms == 12.5
        assert not rebuilt.error

    def test_serialization_with_error(self):
        """The error flag and an infinite latency survive a round-trip."""
        failed = InsertionScheme()
        failed.error = True
        failed.latency_ms = float("inf")

        rebuilt = self._round_trip(failed)

        assert rebuilt.error
        assert rebuilt.latency_ms == float("inf")
independent of insertion point order.""" + scheme1 = InsertionScheme() + scheme1.node_inputs = [NodeInputInsertionPoint(0, 0), NodeInputInsertionPoint(1, 0)] + + scheme2 = InsertionScheme() + scheme2.node_inputs = [ + NodeInputInsertionPoint(1, 0), + NodeInputInsertionPoint(0, 0), # Reversed order + ] + + # Hash should be the same regardless of order + assert scheme1.hash == scheme2.hash + + def test_serialization_empty(self): + """Test serialization of empty scheme.""" + scheme = InsertionScheme() + + data = scheme.to_dict() + restored = InsertionScheme.from_dict(data) + + assert restored.is_empty + assert restored.latency_ms == float("inf") + assert not restored.error + + def test_serialization_full(self): + """Test serialization with all types of insertion points.""" + scheme = InsertionScheme() + scheme.node_inputs = [NodeInputInsertionPoint(0, 0)] + scheme.child_region_inputs = [ChildRegionInputInsertionPoint(0, 0)] + scheme.region_outputs = [RegionOutputInsertionPoint(None, 0, 0)] + scheme.latency_ms = 12.5 + scheme.error = False + + data = scheme.to_dict() + restored = InsertionScheme.from_dict(data) + + assert len(restored.node_inputs) == 1 + assert len(restored.child_region_inputs) == 1 + assert len(restored.region_outputs) == 1 + assert restored.latency_ms == 12.5 + assert not restored.error + + def test_serialization_with_error(self): + """Test serialization with error flag.""" + scheme = InsertionScheme() + scheme.error = True + scheme.latency_ms = float("inf") + + data = scheme.to_dict() + restored = InsertionScheme.from_dict(data) + + assert restored.error + assert restored.latency_ms == float("inf") + + +# ============================================================================= +# Helper functions for creating mock graphs +# ============================================================================= + + +def _create_mock_tensor(name: str, dtype=np.float32, shape=None): + """Create a mock tensor with the specified properties.""" + tensor = 
def _create_mock_node(op: str, inputs: list, outputs: list, name: str = ""):
    """Return a MagicMock (spec'd on ``gs.Node``) standing in for a graph node.

    Args:
        op: Operator type stored on ``op`` (e.g. ``"Conv"``).
        inputs: List stored as-is on ``inputs``.
        outputs: List stored as-is on ``outputs``.
        name: Value stored on ``name``; defaults to the empty string.

    Returns:
        The configured mock node.
    """
    mock_node = MagicMock(spec=gs.Node)
    # configure_mock assigns plain attributes, including ``name``.
    mock_node.configure_mock(op=op, name=name, inputs=inputs, outputs=outputs)
    return mock_node
def _create_simple_graph():
    """Build a mocked linear Conv -> BatchNormalization -> Relu -> MaxPool graph.

    Tensor chain:
        input -> Conv -> conv_out -> BatchNormalization -> bn_out
              -> Relu -> relu_out -> MaxPool -> pool_out

    Position of each node in ``graph.nodes``:
        0: Conv
        1: BatchNormalization
        2: Relu
        3: MaxPool

    Returns:
        ``(graph, tensors)`` where ``graph`` is a MagicMock spec'd on
        ``gs.Graph`` and ``tensors`` maps tensor name to mock tensor for the
        graph input, the Conv weight/bias and every activation. The
        BatchNormalization parameter tensors are reachable only through the
        node's inputs.
    """
    f32 = np.float32

    # Activations along the chain.
    graph_input = _create_mock_tensor("input", f32, [1, 3, 224, 224])
    conv_out = _create_mock_tensor("conv_out", f32, [1, 64, 222, 222])
    bn_out = _create_mock_tensor("bn_out", f32, [1, 64, 222, 222])
    relu_out = _create_mock_tensor("relu_out", f32, [1, 64, 222, 222])
    pool_out = _create_mock_tensor("pool_out", f32, [1, 64, 111, 111])

    # Conv parameters.
    conv_weight = _create_mock_tensor("conv_weight", f32, [64, 3, 3, 3])
    conv_bias = _create_mock_tensor("conv_bias", f32, [64])

    # BatchNormalization parameters, in the order the node consumes them.
    bn_params = [
        _create_mock_tensor(param_name, f32, [64])
        for param_name in ("bn_scale", "bn_bias", "bn_mean", "bn_var")
    ]

    chain = [
        _create_mock_node("Conv", [graph_input, conv_weight, conv_bias], [conv_out], "conv1"),
        _create_mock_node("BatchNormalization", [conv_out, *bn_params], [bn_out], "bn1"),
        _create_mock_node("Relu", [bn_out], [relu_out], "relu1"),
        _create_mock_node("MaxPool", [relu_out], [pool_out], "pool1"),
    ]

    # Each node produces exactly one tensor; record the node as its producer.
    for producer in chain:
        producer.outputs[0].inputs = [producer]
    # Graph input and Conv parameters have no producer.
    for source in (graph_input, conv_weight, conv_bias):
        source.inputs = []

    graph = MagicMock(spec=gs.Graph)
    graph.nodes = chain
    graph.inputs = [graph_input]
    graph.outputs = [pool_out]

    exposed = (graph_input, conv_weight, conv_bias, conv_out, bn_out, relu_out, pool_out)
    tensors = {tensor.name: tensor for tensor in exposed}

    return graph, tensors
def _create_residual_graph():
    """Build a mocked residual block whose input also feeds the final Add.

    Main branch:
        input -> Conv1 -> conv1_out -> Relu1 -> relu1_out -> Conv2 -> conv2_out
    Skip connection and tail:
        Add(conv2_out, input) -> add_out -> Relu2 -> output

    Position of each node in ``graph.nodes``:
        0: Conv1
        1: Relu1
        2: Conv2
        3: Add
        4: Relu2

    Returns:
        ``(graph, tensors)`` where ``graph`` is a MagicMock spec'd on
        ``gs.Graph`` and ``tensors`` maps every tensor name to its mock.
    """
    f32 = np.float32

    def activation(tensor_name):
        # Every activation in this block has the same shape; a fresh list is
        # created per tensor.
        return _create_mock_tensor(tensor_name, f32, [1, 64, 56, 56])

    def conv_kernel(tensor_name):
        return _create_mock_tensor(tensor_name, f32, [64, 64, 3, 3])

    block_input = activation("input")

    # First conv branch.
    kernel1 = conv_kernel("conv1_weight")
    conv1_out = activation("conv1_out")
    relu1_out = activation("relu1_out")

    # Second conv.
    kernel2 = conv_kernel("conv2_weight")
    conv2_out = activation("conv2_out")

    # Add and final relu.
    add_out = activation("add_out")
    block_output = activation("output")

    stages = [
        _create_mock_node("Conv", [block_input, kernel1], [conv1_out], "conv1"),
        _create_mock_node("Relu", [conv1_out], [relu1_out], "relu1"),
        _create_mock_node("Conv", [relu1_out, kernel2], [conv2_out], "conv2"),
        _create_mock_node("Add", [conv2_out, block_input], [add_out], "add1"),
        _create_mock_node("Relu", [add_out], [block_output], "relu2"),
    ]

    # Each node produces exactly one tensor; record the node as its producer.
    for producer in stages:
        producer.outputs[0].inputs = [producer]
    # Graph input and kernels have no producer.
    for source in (block_input, kernel1, kernel2):
        source.inputs = []

    graph = MagicMock(spec=gs.Graph)
    graph.nodes = stages
    graph.inputs = [block_input]
    graph.outputs = [block_output]

    exposed = (
        block_input,
        kernel1,
        conv1_out,
        relu1_out,
        kernel2,
        conv2_out,
        add_out,
        block_output,
    )
    tensors = {tensor.name: tensor for tensor in exposed}

    return graph, tensors
class TestSkipInvalidInsertionPoints(unittest.TestCase):
    """Test skip_invalid_insertion_points function.

    Each test passes a mock graph, a tensor name and a consumer (a node, or a
    Region in ``test_with_region``) and checks whether the function reports
    the location as one to skip (``True``) or to keep (``False``).
    """

    def test_skip_bool_operations(self):
        """Test that an input of a comparison op ("Equal") is skipped."""
        graph, _ = _create_simple_graph()

        # The consumer is an "Equal" node; the tensor itself is float32.
        bool_tensor = _create_mock_tensor("bool_input", np.float32)
        bool_node = _create_mock_node("Equal", [bool_tensor], [])

        result = skip_invalid_insertion_points(graph, "bool_input", bool_node)
        assert result is True

    def test_skip_shape_operations(self):
        """Test that shape operations are skipped."""
        graph, _ = _create_simple_graph()

        shape_tensor = _create_mock_tensor("shape_input", np.float32)
        shape_node = _create_mock_node("Shape", [shape_tensor], [])

        result = skip_invalid_insertion_points(graph, "shape_input", shape_node)
        assert result is True

    def test_skip_conv_weight_input(self):
        """Test that the Conv weight input is skipped."""
        graph, _ = _create_simple_graph()
        conv_node = graph.nodes[0]

        # Weight is at input index 1 of the Conv node
        result = skip_invalid_insertion_points(graph, "conv_weight", conv_node)
        assert result is True

    def test_allow_conv_data_input(self):
        """Test that a large float input of a MatMul node is allowed.

        Despite the test name, the consumer built here is a MatMul node, not a
        Conv; the Conv node of the simple graph is not exercised.
        """
        graph, _ = _create_simple_graph()

        # Stand-alone MatMul consumer with a single large float32 input
        input_tensor = _create_mock_tensor("matmul_input", np.float32, [1, 3, 224, 224])
        matmul_node = _create_mock_node("MatMul", [input_tensor], [])

        result = skip_invalid_insertion_points(graph, "matmul_input", matmul_node)
        assert result is False

    def test_skip_non_float_tensors(self):
        """Test that non-floating-point tensors are skipped."""
        graph, _ = _create_simple_graph()

        # Create int tensor
        int_tensor = _create_mock_tensor("int_input", np.int32)
        node = _create_mock_node("Add", [int_tensor], [])

        result = skip_invalid_insertion_points(graph, "int_input", node)
        assert result is True

    def test_skip_small_tensors(self):
        """Test that small tensors (< 8 elements) are skipped."""
        graph, _ = _create_simple_graph()

        # Create small tensor (one element, shape [1])
        small_tensor = _create_mock_tensor("small", np.float32, [1])
        node = _create_mock_node("Add", [small_tensor], [])

        result = skip_invalid_insertion_points(graph, "small", node)
        assert result is True

    def test_allow_large_float_tensors(self):
        """Test that large floating-point tensors are allowed."""
        graph, _ = _create_simple_graph()

        # Create large float tensor
        large_tensor = _create_mock_tensor("large", np.float32, [1, 64, 32, 32])
        node = _create_mock_node("Add", [large_tensor], [])

        result = skip_invalid_insertion_points(graph, "large", node)
        assert result is False

    def test_skip_bn_non_data_inputs(self):
        """Test that BatchNormalization non-data inputs are skipped."""
        graph, _ = _create_simple_graph()
        bn_node = graph.nodes[1]  # BatchNormalization node

        # Scale is at input index 1, should be skipped
        result = skip_invalid_insertion_points(graph, "bn_scale", bn_node)
        assert result is True

    def test_with_region(self):
        """Test skip_invalid_insertion_points with a Region containing multiple nodes."""
        graph, _ = _create_simple_graph()

        # Create a region containing Conv and BatchNorm nodes
        region = Region(region_id=1, level=0, region_type=RegionType.LEAF)
        region.add_node(0)  # Conv node
        region.add_node(1)  # BatchNorm node

        # Append a Shape node to the graph; it lands at index 4
        shape_tensor = _create_mock_tensor("shape_input", np.float32)
        shape_node = _create_mock_node("Shape", [shape_tensor], [])
        graph.nodes.append(shape_node)
        region.add_node(4)  # Add the shape node to region

        # A Region (not a node) is passed as the consumer here
        result = skip_invalid_insertion_points(graph, "shape_input", region)
        assert result is True

    def test_skip_conv_bn_relu_fusion(self):
        """Test that Conv->BN->Relu fusion patterns are skipped at intermediate points."""
        graph, _ = _create_simple_graph()
        relu_node = graph.nodes[2]  # Relu node

        # Relu input (bn_out) should be skipped when preceded by Conv->BN
        result = skip_invalid_insertion_points(graph, "bn_out", relu_node)
        assert result is True

    def test_residual_block_add_inputs(self):
        """Test that both inputs of the Add in a residual block are allowed."""
        graph, _ = _create_residual_graph()
        add_node = graph.nodes[3]  # Add node

        # Add's first input (conv2_out) should be allowed
        result = skip_invalid_insertion_points(graph, "conv2_out", add_node)
        assert result is False

        # Add's second input (skip connection input) should also be allowed
        result = skip_invalid_insertion_points(graph, "input", add_node)
        assert result is False
class TestHasQuantizableOperations(unittest.TestCase):
    """Checks for has_quantizable_operations on LEAF and COMPOSITE regions."""

    @staticmethod
    def _leaf_region(*node_indices):
        """Return a level-0 LEAF region holding the given node indices, in order."""
        leaf = Region(region_id=1, level=0, region_type=RegionType.LEAF)
        for node_idx in node_indices:
            leaf.add_node(node_idx)
        return leaf

    def test_leaf_with_conv(self):
        """A LEAF region holding the Conv node is reported as quantizable."""
        graph, _ = _create_simple_graph()
        leaf = self._leaf_region(0)  # Conv

        assert has_quantizable_operations(leaf, graph) is True

    def test_leaf_with_maxpool(self):
        """A LEAF region holding only MaxPool is reported as quantizable."""
        graph, _ = _create_simple_graph()
        leaf = self._leaf_region(3)  # MaxPool

        assert has_quantizable_operations(leaf, graph) is True

    def test_leaf_with_relu_only(self):
        """A LEAF region holding only Relu is reported as quantizable."""
        graph, _ = _create_simple_graph()
        leaf = self._leaf_region(2)  # Relu is node 2 of the simple graph

        # Expected because Relu counts as a major quantizable operation.
        assert has_quantizable_operations(leaf, graph) is True

    def test_leaf_with_conv_bn_relu(self):
        """A LEAF region holding Conv, BatchNorm and Relu is quantizable."""
        graph, _ = _create_simple_graph()
        leaf = self._leaf_region(0, 1, 2)  # Conv, BatchNorm, Relu

        assert has_quantizable_operations(leaf, graph) is True

    def test_leaf_without_quantizable_ops(self):
        """A LEAF region of Shape and Transpose nodes is not quantizable."""
        # Two-node graph: Shape feeding Transpose.
        shape_in = _create_mock_tensor("input", np.float32)
        shape_out = _create_mock_tensor("output", np.float32)
        graph = MagicMock(spec=gs.Graph)
        graph.nodes = [
            _create_mock_node("Shape", [shape_in], [shape_out]),
            _create_mock_node("Transpose", [shape_out], []),
        ]

        leaf = self._leaf_region(0, 1)

        assert has_quantizable_operations(leaf, graph) is False

    def test_composite_region_always_true(self):
        """A COMPOSITE region with no nodes of its own still returns True."""
        graph, _ = _create_simple_graph()

        # No nodes are added to the composite region.
        composite = Region(region_id=1, level=1, region_type=RegionType.COMPOSITE)

        assert has_quantizable_operations(composite, graph) is True

    def test_residual_block_has_quantizable_ops(self):
        """A LEAF region holding only the residual Add is quantizable."""
        graph, _ = _create_residual_graph()
        leaf = self._leaf_region(3)  # Add

        # Expected because Add counts as a major quantizable operation.
        assert has_quantizable_operations(leaf, graph) is True
class TestResolveRegionIOInsertionPoints(unittest.TestCase):
    """Test resolve_region_io_insertion_points function.

    The function is called with a region (or ``None``), the mock graph and a
    tensor name; the tests inspect the returned resolved insertion points.
    The graph's ``tensor_users_map`` attribute is populated before each call.
    """

    def test_resolve_with_region(self):
        """Test resolving relu_out against a region that holds only the Relu node."""
        graph, _ = _create_simple_graph()

        # Build the tensor-name -> consumer-node-indices map for the whole graph
        graph.tensor_users_map = get_tensor_consumer_node_indices(graph)

        region = Region(region_id=1, level=0, region_type=RegionType.LEAF)
        region.add_node(2)  # Relu node

        result = resolve_region_io_insertion_points(region, graph, "relu_out")

        assert len(result) >= 1
        assert any(ip.tensor_name == "relu_out" for ip in result)

    def test_resolve_without_region(self):
        """Test resolving without a region (None) for tensor-level insertion."""
        graph, _ = _create_simple_graph()

        # Build the consumer map; relu_out is consumed by MaxPool (node 3, input 0)
        graph.tensor_users_map = get_tensor_consumer_node_indices(graph)

        result = resolve_region_io_insertion_points(None, graph, "relu_out")

        assert len(result) == 1
        ip = next(iter(result))
        assert ip.tensor_name == "relu_out"
        assert ip.node_index == 3
        assert ip.input_index == 0

    def test_resolve_tensor_not_found(self):
        """Test resolving a tensor that has no users."""
        graph, _ = _create_simple_graph()
        graph.tensor_users_map = {}

        result = resolve_region_io_insertion_points(None, graph, "nonexistent")

        assert len(result) == 0

    def test_resolve_residual_skip_connection(self):
        """Test resolving input tensor used by both Conv1 and Add (skip connection)."""
        graph, _ = _create_residual_graph()

        # Input tensor is used by Conv1 (node 0) and Add (node 3)
        graph.tensor_users_map = {"input": [0, 3]}

        result = resolve_region_io_insertion_points(None, graph, "input")

        # Should find both consumers
        assert len(result) == 2
        node_indices = {ip.node_index for ip in result}
        assert 0 in node_indices  # Conv1
        assert 3 in node_indices  # Add

    def test_resolve_with_multiple_consumers(self):
        """Test resolving a tensor whose consumer lies inside the region.

        Despite the test name, the users map lists a single consumer
        (Conv2, node 2) for relu1_out.
        """
        graph, _ = _create_residual_graph()

        # relu1_out feeds conv2 (node 2)
        graph.tensor_users_map = {"relu1_out": [2]}

        region = Region(region_id=1, level=0, region_type=RegionType.LEAF)
        region.add_node(2)  # Conv2

        result = resolve_region_io_insertion_points(region, graph, "relu1_out")

        assert len(result) == 1
        ip = next(iter(result))
        assert ip.tensor_name == "relu1_out"
        assert ip.node_index == 2
class TestMergeResolvedInsertionPoints(unittest.TestCase):
    """Checks for merge_resolved_insertion_points with a patched consumer lookup."""

    # Dotted path of the consumer lookup that every test replaces with a stub.
    _CONSUMER_LOOKUP = (
        "modelopt.onnx.quantization.autotune.insertion_points.get_tensor_consumer_node_indices"
    )

    def _merge(self, graph, resolved, consumers):
        """Run the merge while the consumer lookup returns ``consumers``."""
        with patch(self._CONSUMER_LOOKUP, return_value=consumers):
            return merge_resolved_insertion_points(graph, resolved)

    def test_merge_all_users(self):
        """Node-level points covering every user collapse to one tensor-level point."""
        graph, _ = _create_simple_graph()

        # "conv_out" has a single user, BatchNorm (node 1), and it has a point.
        node_points = {
            ResolvedInsertionPoint(tensor_name="conv_out", node_index=1, input_index=0),
        }

        merged_points = self._merge(graph, node_points, {"conv_out": [1]})

        assert len(merged_points) == 1
        only_point = next(iter(merged_points))
        assert only_point.tensor_name == "conv_out"
        assert only_point.node_index is None
        assert only_point.input_index is None

    def test_no_merge_partial_users(self):
        """Points covering only some users stay node-specific."""
        graph, _ = _create_simple_graph()

        # "conv_out" is reported as used by nodes 1 and 2; only node 1 has a point.
        node_points = {
            ResolvedInsertionPoint(tensor_name="conv_out", node_index=1, input_index=0),
        }

        merged_points = self._merge(graph, node_points, {"conv_out": [1, 2]})

        assert len(merged_points) == 1
        only_point = next(iter(merged_points))
        assert only_point.node_index == 1  # still node-specific

    def test_preserve_tensor_level_insertions(self):
        """A point that is already tensor-level passes through unchanged."""
        graph, _ = _create_simple_graph()

        tensor_points = {
            ResolvedInsertionPoint(tensor_name="input", node_index=None, input_index=None),
        }

        merged_points = self._merge(graph, tensor_points, {"conv_out": [1]})

        assert len(merged_points) == 1
        only_point = next(iter(merged_points))
        assert only_point.tensor_name == "input"
        assert only_point.node_index is None

    def test_merge_residual_skip_connection(self):
        """Points on both users of the residual input collapse to tensor level."""
        graph, _ = _create_residual_graph()

        # "input" feeds Conv1 (node 0) and Add (node 3); both have a point.
        node_points = {
            ResolvedInsertionPoint(tensor_name="input", node_index=0, input_index=0),
            ResolvedInsertionPoint(tensor_name="input", node_index=3, input_index=1),
        }

        merged_points = self._merge(graph, node_points, {"input": [0, 3]})

        assert len(merged_points) == 1
        only_point = next(iter(merged_points))
        assert only_point.tensor_name == "input"
        assert only_point.node_index is None

    def test_no_merge_residual_partial(self):
        """A point on only one user of the residual input stays node-specific."""
        graph, _ = _create_residual_graph()

        # "input" feeds Conv1 (node 0) and Add (node 3); only Conv1 has a point.
        node_points = {
            ResolvedInsertionPoint(tensor_name="input", node_index=0, input_index=0),
        }

        merged_points = self._merge(graph, node_points, {"input": [0, 3]})

        assert len(merged_points) == 1
        only_point = next(iter(merged_points))
        assert only_point.node_index == 0  # still node-specific
class TestNodeInputInsertionPointResolve(unittest.TestCase):
    """Checks for NodeInputInsertionPoint.resolve() on LEAF regions."""

    @staticmethod
    def _leaf_region(*node_indices):
        """Return a level-0 LEAF region holding the given node indices, in order."""
        leaf = Region(region_id=1, level=0, region_type=RegionType.LEAF)
        for node_idx in node_indices:
            leaf.add_node(node_idx)
        return leaf

    def test_resolve_simple(self):
        """Input 0 of the first node in a full Conv->BN->Relu->Pool region resolves to "input"."""
        graph, _ = _create_simple_graph()
        leaf = self._leaf_region(0, 1, 2, 3)  # Conv, BatchNorm, Relu, MaxPool

        # First input of the first node (Conv).
        resolved = NodeInputInsertionPoint(node_index=0, input_index=0).resolve(leaf, graph)

        assert len(resolved) >= 1
        assert any(point.tensor_name == "input" for point in resolved)

    def test_resolve_conv_includes_weight(self):
        """Resolving the Conv data input yields both the data and the weight tensor."""
        graph, _ = _create_simple_graph()
        leaf = self._leaf_region(0)  # Conv

        resolved = NodeInputInsertionPoint(node_index=0, input_index=0).resolve(leaf, graph)

        assert len(resolved) == 2
        resolved_names = {point.tensor_name for point in resolved}
        assert "input" in resolved_names
        assert "conv_weight" in resolved_names

    def test_resolve_relu_input(self):
        """Input 0 of the Relu in the middle of the chain resolves to bn_out only."""
        graph, _ = _create_simple_graph()
        leaf = self._leaf_region(0, 1, 2)  # Conv, BatchNorm, Relu

        # Relu sits at index 2 within the region; its input 0 is bn_out.
        resolved = NodeInputInsertionPoint(node_index=2, input_index=0).resolve(leaf, graph)

        assert len(resolved) == 1
        only_point = next(iter(resolved))
        assert only_point.tensor_name == "bn_out"

    def test_resolve_residual_conv_input(self):
        """Resolving Conv2's data input in the residual block yields data and weight."""
        graph, _ = _create_residual_graph()
        leaf = self._leaf_region(0, 1, 2)  # Conv1, Relu1, Conv2

        # Conv2 sits at index 2 within the region; its input 0 is relu1_out.
        resolved = NodeInputInsertionPoint(node_index=2, input_index=0).resolve(leaf, graph)

        assert len(resolved) == 2
        resolved_names = {point.tensor_name for point in resolved}
        assert "relu1_out" in resolved_names
        assert "conv2_weight" in resolved_names
class TestChildRegionInputInsertionPointResolve(unittest.TestCase):
    """Test ChildRegionInputInsertionPoint.resolve() method.

    ``resolve`` is called with the parent region and the mock graph; the
    point's ``region_index`` selects a child of that parent and
    ``input_index`` selects an entry of the child's ``inputs`` list.
    """

    def test_resolve_composite_region(self):
        """Test resolving child region input in COMPOSITE region."""
        graph, _ = _create_simple_graph()
        graph.tensor_users_map = {"input": [0]}

        # Create parent (COMPOSITE) with child (LEAF) containing Conv->BN->Relu
        parent = Region(region_id=1, level=1, region_type=RegionType.COMPOSITE)
        child = Region(region_id=2, level=0, region_type=RegionType.LEAF)
        child.inputs = ["input"]
        child.add_node(0)  # Conv
        child.add_node(1)  # BatchNorm
        child.add_node(2)  # Relu
        parent.add_child(child)

        ip = ChildRegionInputInsertionPoint(region_index=0, input_index=0)

        result = ip.resolve(parent, graph)

        assert len(result) >= 1
        assert any(rip.tensor_name == "input" for rip in result)

    def test_resolve_leaf_returns_empty(self):
        """Test that LEAF regions return empty set."""
        graph, _ = _create_simple_graph()

        leaf = Region(region_id=1, level=0, region_type=RegionType.LEAF)
        leaf.add_node(0)

        ip = ChildRegionInputInsertionPoint(region_index=0, input_index=0)

        result = ip.resolve(leaf, graph)

        assert len(result) == 0

    def test_resolve_multiple_children(self):
        """Test resolving child inputs in COMPOSITE with multiple children."""
        graph, _ = _create_residual_graph()
        # input is consumed by Conv1 (node 0) and Add (node 3)
        graph.tensor_users_map = get_tensor_consumer_node_indices(graph)

        # Create parent with two child regions
        parent = Region(region_id=1, level=1, region_type=RegionType.COMPOSITE)

        # First child: Conv1 (consumes "input")
        child1 = Region(region_id=2, level=0, region_type=RegionType.LEAF)
        child1.inputs = ["input"]
        child1.add_node(0)  # Conv1

        # Second child: Conv2 (node 2 of the residual graph, consumes "relu1_out")
        child2 = Region(region_id=3, level=0, region_type=RegionType.LEAF)
        child2.inputs = ["relu1_out"]
        child2.add_node(2)  # Conv2

        parent.add_child(child1)
        parent.add_child(child2)

        # Resolve input of first child (region_index=0) - "input" tensor
        ip1 = ChildRegionInputInsertionPoint(region_index=0, input_index=0)
        result1 = ip1.resolve(parent, graph)

        assert len(result1) >= 1
        assert any(rip.tensor_name == "input" for rip in result1)

        # Resolve input of second child (region_index=1) - "relu1_out" tensor
        ip2 = ChildRegionInputInsertionPoint(region_index=1, input_index=0)
        result2 = ip2.resolve(parent, graph)

        assert len(result2) >= 1
        assert any(rip.tensor_name == "relu1_out" for rip in result2)
class TestRegionOutputInsertionPointResolve(unittest.TestCase):
    """Test RegionOutputInsertionPoint.resolve() method.

    Points are built either with ``node_index`` (output of a node in the
    region) or with ``region_index`` (output of a child region); ``resolve``
    receives the region and the mock graph.
    """

    def test_resolve_node_output(self):
        """Test resolving a node output (the Relu output, relu_out)."""
        graph, _ = _create_simple_graph()
        graph.tensor_users_map = get_tensor_consumer_node_indices(graph)

        region = Region(region_id=1, level=0, region_type=RegionType.LEAF)
        region.add_node(0)  # Conv
        region.add_node(1)  # BatchNorm
        region.add_node(2)  # Relu
        region.add_node(3)  # MaxPool
        region.outputs = ["pool_out"]

        # Output 0 of node index 2, i.e. the Relu node (not the final MaxPool)
        ip = RegionOutputInsertionPoint(region_index=None, node_index=2, output_index=0)

        result = ip.resolve(region, graph)

        assert len(result) >= 1
        assert any(rip.tensor_name == "relu_out" for rip in result)

    def test_resolve_child_region_output(self):
        """Test resolving a child region output."""
        graph, _ = _create_simple_graph()
        # relu_out is consumed by MaxPool (node 3)
        graph.tensor_users_map = {"relu_out": [3]}

        parent = Region(region_id=1, level=1, region_type=RegionType.COMPOSITE)
        child = Region(region_id=2, level=0, region_type=RegionType.LEAF)
        child.outputs = ["relu_out"]
        child.add_node(0)  # Conv
        child.add_node(1)  # BatchNorm
        child.add_node(2)  # Relu
        parent.add_child(child)

        # Output 0 of the first (only) child region
        ip = RegionOutputInsertionPoint(region_index=0, node_index=None, output_index=0)

        result = ip.resolve(parent, graph)

        assert len(result) >= 1
        assert any(rip.tensor_name == "relu_out" for rip in result)

    def test_resolve_residual_add_output(self):
        """Test resolving Add output in residual block."""
        graph, _ = _create_residual_graph()
        # add_out is consumed by Relu2 (node 4)
        graph.tensor_users_map = {"add_out": [4]}

        region = Region(region_id=1, level=0, region_type=RegionType.LEAF)
        region.add_node(0)  # Conv1
        region.add_node(1)  # Relu1
        region.add_node(2)  # Conv2
        region.add_node(3)  # Add
        region.add_node(4)  # Relu2
        region.outputs = ["add_out"]

        # Add is at local index 3, output 0
        ip = RegionOutputInsertionPoint(region_index=None, node_index=3, output_index=0)

        result = ip.resolve(region, graph)

        assert len(result) >= 1
        assert any(rip.tensor_name == "add_out" for rip in result)
test_resolve_residual_add_output(self): + """Test resolving Add output in residual block.""" + graph, tensors = _create_residual_graph() + graph.tensor_users_map = {"add_out": [4]} + + region = Region(region_id=1, level=0, region_type=RegionType.LEAF) + region.add_node(0) # Conv1 + region.add_node(1) # Relu1 + region.add_node(2) # Conv2 + region.add_node(3) # Add + region.add_node(4) # Relu2 + region.outputs = ["add_out"] + + # Add is at local index 3, output 0 + ip = RegionOutputInsertionPoint(region_index=None, node_index=3, output_index=0) + + result = ip.resolve(region, graph) + + assert len(result) >= 1 + assert any(rip.tensor_name == "add_out" for rip in result) + + +# ============================================================================= +# Collect From Region Tests +# ============================================================================= + + +class TestNodeInputInsertionPointCollectFrom(unittest.TestCase): + """Test NodeInputInsertionPoint.collect_from_region() method.""" + + def test_collect_valid_inputs(self): + """Test collecting valid node input insertion points from Conv->BN->Relu->Pool.""" + graph, tensors = _create_simple_graph() + + region = Region(region_id=1, level=0, region_type=RegionType.LEAF) + region.add_node(0) # Conv + region.add_node(1) # BatchNorm + region.add_node(2) # Relu + region.add_node(3) # MaxPool + + result = NodeInputInsertionPoint.collect_from_region(region, graph) + + # Should have collected some insertion points + assert len(result) >= 1 + # All should be NodeInputInsertionPoint + assert all(isinstance(ip, NodeInputInsertionPoint) for ip in result) + + def test_collect_from_residual_block(self): + """Test collecting from residual block with skip connection.""" + graph, tensors = _create_residual_graph() + + region = Region(region_id=1, level=0, region_type=RegionType.LEAF) + region.add_node(0) # Conv1 + region.add_node(1) # Relu1 + region.add_node(2) # Conv2 + region.add_node(3) # Add + region.add_node(4) # 
Relu2 + + result = NodeInputInsertionPoint.collect_from_region(region, graph) + + # Should have collected insertion points from Conv1, Add inputs, etc. + assert len(result) >= 1 + assert all(isinstance(ip, NodeInputInsertionPoint) for ip in result) + + # Check that we have insertion points for different nodes + node_indices = {ip.node_index for ip in result} + assert len(node_indices) >= 1 # At least one node has valid inputs + + +class TestChildRegionInputInsertionPointCollectFrom(unittest.TestCase): + """Test ChildRegionInputInsertionPoint.collect_from_region() method.""" + + def test_collect_from_composite(self): + """Test collecting from COMPOSITE region with children.""" + graph, tensors = _create_simple_graph() + + parent = Region(region_id=1, level=1, region_type=RegionType.COMPOSITE) + child = Region(region_id=2, level=0, region_type=RegionType.LEAF) + child.inputs = ["input"] + child.add_node(0) # Conv + child.add_node(1) # BatchNorm + child.add_node(2) # Relu + parent.add_child(child) + + result = ChildRegionInputInsertionPoint.collect_from_region(parent, graph) + + # Should find the child's input + assert len(result) >= 0 # May be filtered by skip_invalid_insertion_points + assert all(isinstance(ip, ChildRegionInputInsertionPoint) for ip in result) + + def test_collect_from_leaf_returns_empty(self): + """Test that LEAF regions return empty list.""" + graph, _ = _create_simple_graph() + + leaf = Region(region_id=1, level=0, region_type=RegionType.LEAF) + leaf.add_node(0) + + result = ChildRegionInputInsertionPoint.collect_from_region(leaf, graph) + + assert len(result) == 0 + + def test_collect_from_composite_with_multiple_children(self): + """Test collecting from COMPOSITE with multiple child regions.""" + graph, tensors = _create_residual_graph() + + parent = Region(region_id=1, level=1, region_type=RegionType.COMPOSITE) + + child1 = Region(region_id=2, level=0, region_type=RegionType.LEAF) + child1.inputs = ["input"] + child1.add_node(0) # Conv1 + 
child1.add_node(1) # Relu1 + + child2 = Region(region_id=3, level=0, region_type=RegionType.LEAF) + child2.inputs = ["relu1_out", "input"] # Two inputs including skip connection + child2.add_node(2) # Conv2 + child2.add_node(3) # Add + + parent.add_child(child1) + parent.add_child(child2) + + result = ChildRegionInputInsertionPoint.collect_from_region(parent, graph) + + # Should find inputs from both children + assert all(isinstance(ip, ChildRegionInputInsertionPoint) for ip in result) + + +class TestRegionOutputInsertionPointCollectFrom(unittest.TestCase): + """Test RegionOutputInsertionPoint.collect_from_region() method.""" + + def test_collect_node_outputs(self): + """Test collecting node output insertion points.""" + graph, tensors = _create_simple_graph() + + region = Region(region_id=1, level=0, region_type=RegionType.LEAF) + region.add_node(0) # Conv + region.add_node(1) # BatchNorm + region.add_node(2) # Relu + region.add_node(3) # MaxPool + region.outputs = ["pool_out"] # Only pool_out is a region output + + result = RegionOutputInsertionPoint.collect_from_region(region, graph) + + # Should find the node output that matches region output + assert len(result) >= 0 # May be filtered + assert all(isinstance(ip, RegionOutputInsertionPoint) for ip in result) + + def test_collect_child_region_outputs(self): + """Test collecting child region output insertion points.""" + graph, tensors = _create_simple_graph() + + parent = Region(region_id=1, level=1, region_type=RegionType.COMPOSITE) + child = Region(region_id=2, level=0, region_type=RegionType.LEAF) + child.outputs = ["relu_out"] + child.add_node(0) # Conv + child.add_node(1) # BatchNorm + child.add_node(2) # Relu + parent.add_child(child) + parent.outputs = ["relu_out"] # Child output is also parent output + + result = RegionOutputInsertionPoint.collect_from_region(parent, graph) + + # Should find the child region output + assert all(isinstance(ip, RegionOutputInsertionPoint) for ip in result) + + def 
test_collect_residual_block_outputs(self): + """Test collecting outputs from residual block.""" + graph, tensors = _create_residual_graph() + + region = Region(region_id=1, level=0, region_type=RegionType.LEAF) + region.add_node(0) # Conv1 + region.add_node(1) # Relu1 + region.add_node(2) # Conv2 + region.add_node(3) # Add + region.add_node(4) # Relu2 + region.outputs = ["output"] # Final output + + result = RegionOutputInsertionPoint.collect_from_region(region, graph) + + # Should find the output + assert all(isinstance(ip, RegionOutputInsertionPoint) for ip in result) diff --git a/tests/unit/onnx/quantization/autotune/test_region.py b/tests/unit/onnx/quantization/autotune/test_region.py new file mode 100644 index 000000000..714f8a051 --- /dev/null +++ b/tests/unit/onnx/quantization/autotune/test_region.py @@ -0,0 +1,167 @@ +#!/usr/bin/env python3 +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Tests for the Region class in the autotuner. + +Tests region creation, hierarchy, and boundary management. 
+""" + +import unittest + +from modelopt.onnx.quantization.autotune.common import Region, RegionType + + +class TestRegion(unittest.TestCase): + """Test Region class functionality.""" + + def test_region_creation(self): + """Test creating regions of all types.""" + test_cases = [ + {"region_id": 1, "level": 0, "region_type": RegionType.LEAF}, + {"region_id": 2, "level": 1, "region_type": RegionType.COMPOSITE}, + {"region_id": 0, "level": 2, "region_type": RegionType.ROOT}, + ] + + for params in test_cases: + with self.subTest(**params): + region = Region(**params) + assert region.id == params["region_id"] + assert region.level == params["level"] + assert region.type == params["region_type"] + + def test_parent_child_relationship(self): + """Test parent-child relationships.""" + parent = Region(region_id=1, level=1, region_type=RegionType.COMPOSITE) + child1 = Region(region_id=2, level=0, region_type=RegionType.LEAF) + child2 = Region(region_id=3, level=0, region_type=RegionType.LEAF) + + parent.add_child(child1) + parent.add_child(child2) + + assert len(parent.get_children()) == 2 + assert child1.parent == parent + assert child2.parent == parent + assert child1 in parent.get_children() + assert child2 in parent.get_children() + + def test_add_nodes(self): + """Test adding nodes to a region.""" + region = Region(region_id=1, level=0, region_type=RegionType.LEAF) + + region.add_node(0) + region.add_node(1) + region.add_node(2) + + assert len(region.nodes) == 3 + assert 0 in region.get_nodes() + assert 1 in region.get_nodes() + assert 2 in region.get_nodes() + + def test_input_output_tensors(self): + """Test setting input and output tensors.""" + region = Region(region_id=1, level=0, region_type=RegionType.LEAF) + + # Directly assign to inputs/outputs attributes + region.inputs = ["input_tensor_1", "input_tensor_2"] + region.outputs = ["output_tensor_1"] + + assert len(region.inputs) == 2 + assert len(region.outputs) == 1 + assert "input_tensor_1" in region.inputs + 
assert "output_tensor_1" in region.outputs + + def test_region_size_recursive(self): + """Test recursive size calculation.""" + parent = Region(region_id=1, level=1, region_type=RegionType.COMPOSITE) + child1 = Region(region_id=2, level=0, region_type=RegionType.LEAF) + child2 = Region(region_id=3, level=0, region_type=RegionType.LEAF) + + # Add nodes to children + child1.add_node(0) + child1.add_node(1) + child2.add_node(2) + child2.add_node(3) + child2.add_node(4) + + # Add children to parent + parent.add_child(child1) + parent.add_child(child2) + + # Parent itself might have direct nodes + parent.add_node(5) + + # Recursive count should include all nodes + assert len(parent.get_region_nodes_and_descendants()) == 6 + + def test_metadata(self): + """Test region metadata storage.""" + region = Region(region_id=1, level=0, region_type=RegionType.LEAF) + + region.metadata["pattern"] = "Conv->Relu" + region.metadata["quantizable"] = "true" + + assert region.metadata["pattern"] == "Conv->Relu" + assert region.metadata["quantizable"] == "true" + + def test_region_type_checks(self): + """Test checking region types (LEAF and COMPOSITE).""" + leaf = Region(region_id=1, level=0, region_type=RegionType.LEAF) + composite = Region(region_id=2, level=1, region_type=RegionType.COMPOSITE) + + assert leaf.type == RegionType.LEAF + assert leaf.type != RegionType.COMPOSITE + assert composite.type == RegionType.COMPOSITE + assert composite.type != RegionType.LEAF + + def test_hierarchical_structure(self): + """Test complex hierarchical structure.""" + root = Region(region_id=0, level=2, region_type=RegionType.ROOT) + composite1 = Region(region_id=1, level=1, region_type=RegionType.COMPOSITE) + composite2 = Region(region_id=2, level=1, region_type=RegionType.COMPOSITE) + leaf1 = Region(region_id=3, level=0, region_type=RegionType.LEAF) + leaf2 = Region(region_id=4, level=0, region_type=RegionType.LEAF) + leaf3 = Region(region_id=5, level=0, region_type=RegionType.LEAF) + + # Build 
hierarchy + root.add_child(composite1) + root.add_child(composite2) + composite1.add_child(leaf1) + composite1.add_child(leaf2) + composite2.add_child(leaf3) + + # Add some nodes + leaf1.add_node(0) + leaf2.add_node(1) + leaf3.add_node(2) + + # Verify structure + assert len(root.get_children()) == 2 + assert len(composite1.get_children()) == 2 + assert len(composite2.get_children()) == 1 + assert len(root.get_region_nodes_and_descendants()) == 3 + + def test_remove_child(self): + """Test removing a child region.""" + parent = Region(region_id=1, level=1, region_type=RegionType.COMPOSITE) + child = Region(region_id=2, level=0, region_type=RegionType.LEAF) + + parent.add_child(child) + assert len(parent.get_children()) == 1 + + parent.remove_child(child) + assert len(parent.get_children()) == 0 + assert child.parent is None From 844f286125871c4ddc41a6236a606ef6d7401033 Mon Sep 17 00:00:00 2001 From: Will Guo Date: Mon, 26 Jan 2026 05:28:22 +0000 Subject: [PATCH 2/5] Part-1 recent refactors Signed-off-by: Will Guo --- modelopt/onnx/quantization/autotune/common.py | 501 ++-------- .../quantization/autotune/insertion_points.py | 888 +++++------------- .../autotune/test_insertion_points.py | 1 - .../onnx/quantization/autotune/test_region.py | 1 - 4 files changed, 299 insertions(+), 1092 deletions(-) diff --git a/modelopt/onnx/quantization/autotune/common.py b/modelopt/onnx/quantization/autotune/common.py index 9ce18827a..cab73039a 100644 --- a/modelopt/onnx/quantization/autotune/common.py +++ b/modelopt/onnx/quantization/autotune/common.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -13,53 +13,23 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Common data structures and types for the QDQ Autotuner. - -This module provides the foundational classes used throughout the autotuner: - -**Exceptions:** -- Region-related: RegionError -- Autotuner-related: AutotunerError, AutotunerNotInitializedError, InvalidSchemeError - -**Region Hierarchy:** -- Region: Hierarchical subgraph representation with parent/child relationships -- RegionType: Enumeration for LEAF, COMPOSITE, and ROOT regions - -**Q/DQ Insertion Specifications:** -- InsertionScheme: Collection of insertion points with performance metrics - -**Scheme Management:** -- PatternSchemes: Multiple insertion schemes for a pattern (applies to all matching regions) -- PatternCache: Collection of top schemes for multiple patterns, used as autotuning seeds - -**Configuration:** -- Config: Autotuning parameters and Q/DQ default values -""" +"""Common data structures and types for the QDQ Autotuner.""" import hashlib -import logging from dataclasses import dataclass, field from enum import Enum -from typing import Any, Optional +from typing import Any import onnx_graphsurgeon as gs +from modelopt.onnx.logging_config import logger from modelopt.onnx.quantization.autotune.insertion_points import ( ChildRegionInputInsertionPoint, NodeInputInsertionPoint, RegionOutputInsertionPoint, ) -# Module logger -logger = logging.getLogger(__name__) - - -# Region-related Exceptions -class RegionError(Exception): - """Base exception for region-related errors.""" - -# Autotuner-related Exceptions class AutotunerError(Exception): """Base exception for autotuner-related errors.""" @@ -86,29 +56,12 @@ class RegionType(Enum): class Region: - """Hierarchical subgraph region in an ONNX computation graph. 
+ """A subgraph region in an ONNX graph, used as the unit for Q/DQ insertion. - A Region represents a cohesive subgraph with well-defined boundaries, supporting: - - **Hierarchical Structure:** - - Parent/child relationships forming a multi-level hierarchy - - LEAF regions contain only direct nodes - - COMPOSITE regions contain child regions (and optionally direct nodes) - - ROOT regions encompass the entire graph - - **Node Management:** - - Direct nodes: Operations directly in this region (not in children) - - Recursive nodes: All operations including those in descendant regions - - **Boundary Tracking:** - - Input tensors: Data entering the region from outside - - Output tensors: Data leaving the region to outside consumers - - **Pattern Matching:** - - Regions with identical structure share the same pattern signature - - Pattern-based optimization applies schemes to all matching regions - - Regions are the fundamental unit for Q/DQ insertion and optimization. + Regions form a hierarchy: ROOT contains the entire graph, COMPOSITE regions + contain child regions, and LEAF regions contain only nodes. Each region tracks + its direct nodes, input/output tensors, and a pattern signature for matching + regions with identical structure. 
""" def __init__(self, region_id: int, level: int, region_type: RegionType): @@ -127,119 +80,64 @@ def __init__(self, region_id: int, level: int, region_type: RegionType): self.nodes: set[int] = set() self.inputs: list[str] = [] self.outputs: list[str] = [] + self.metadata: dict[str, str] = {} - # ========================================================================= - # Basic Accessors - # ========================================================================= - - def get_id(self) -> int: - """Get region ID.""" - return self.id - - def set_id(self, region_id: int) -> None: - """Set region ID (for RegionBuilder use).""" - self.id = region_id - - def get_level(self) -> int: - """Get region level in hierarchy.""" - return self.level - - def set_level(self, level: int) -> None: - """Set region level in hierarchy (for RegionBuilder use).""" - self.level = level - - def get_type(self) -> RegionType: - """Get region type.""" - return self.type - - def set_type(self, region_type: RegionType) -> None: - """Set region type (for RegionBuilder use).""" - self.type = region_type - - # ========================================================================= - # Hierarchy Management - # ========================================================================= - - def get_parent(self) -> Optional["Region"]: - """Get parent region.""" - return self.parent - - def set_parent(self, parent: Optional["Region"]) -> None: - """Set parent region.""" - self.parent = parent - - def get_children(self) -> list["Region"]: + def get_children(self, *, sort: bool = False) -> list["Region"]: """Get all child regions.""" + if sort: + return sorted( + self.children, key=lambda r: (-r.level, r.get_size_of_region_and_descendants()) + ) return self.children def remove_child(self, child: "Region") -> bool: - """Remove a child region from this region's children list. 
- - Args: - child: The child region to remove - - Returns: - True if child was found and removed, False otherwise - """ - child_id = child.get_id() - initial_count = len(self.children) - self.children = [c for c in self.children if c.get_id() != child_id] - removed = len(self.children) < initial_count - - if removed and child.parent and child.parent.get_id() == self.id: - child.set_parent(None) - - return removed + """Remove a child region from this region's children list.""" + if child not in self.children: + return False + self.children.remove(child) + if child.parent and child.parent.id == self.id: + child.parent = None + return True def add_child(self, child: "Region") -> None: """Add a child sub-region.""" - # Prevent adding self as child - if child.get_id() == self.id: + if child.id == self.id: logger.warning(f"Cannot add region {self.id} as its own child") return - # Prevent creating cycles: check if self is already a descendant of child - if self._is_descendant_of(child): + if self.is_descendant_of(child): logger.warning( - f"Cycle detected: region {self.id} is already a descendant of region {child.get_id()}" + f"Cycle detected: region {self.id} is already a descendant of region {child.id}" ) return - # Check if child already has a different parent - if child.parent is not None and child.parent.get_id() != self.id: - old_parent_id = child.parent.get_id() + if child.parent is not None and child.parent.id != self.id: + old_parent_id = child.parent.id logger.debug( - f"Re-parenting region {child.get_id()}: moving from parent {old_parent_id} to {self.id}" + f"Re-parenting region {child.id}: moving from parent {old_parent_id} to {self.id}" ) - # Remove from old parent to maintain tree structure child.parent.remove_child(child) - # Check if child is already in children list - if any(c.get_id() == child.get_id() for c in self.children): - logger.debug(f"Region {child.get_id()} already child of {self.id}") + if any(c.id == child.id for c in self.children): + 
logger.debug(f"Region {child.id} already child of {self.id}") return self.children.append(child) - child.set_parent(self) + child.parent = self - def _is_descendant_of(self, potential_ancestor: "Region") -> bool: + def is_descendant_of(self, potential_ancestor: "Region") -> bool: """Check if this region is a descendant of potential_ancestor.""" visited = set() current = self.parent while current: - if current.get_id() in visited: - # Already visited, there's a cycle in parents + if current.id in visited: return False - visited.add(current.get_id()) - if current.get_id() == potential_ancestor.get_id(): + visited.add(current.id) + if current.id == potential_ancestor.id: return True current = current.parent return False - # ========================================================================= - # Node Management - # ========================================================================= - def add_node(self, node_index: int) -> None: """Add a node index to this region.""" self.nodes.add(node_index) @@ -248,65 +146,33 @@ def add_nodes(self, node_indices: list[int]) -> None: """Add multiple node indices to this region.""" self.nodes.update(node_indices) - def get_nodes(self) -> set[int]: - """Get direct node indices in this region only. - - Returns only nodes directly owned by this region, excluding nodes - in child regions. Use get_all_nodes_recursive() for complete coverage. - - Returns: - Set of node indices (absolute positions in the graph) - """ - return self.nodes - - def get_all_nodes_recursive(self, _visited: set[int] | None = None) -> set[int]: - """Get all node indices recursively, including descendants. - - Traverses the entire subtree rooted at this region, collecting nodes - from this region and all child regions recursively. 
- - Args: - _visited: Internal parameter for cycle detection (do not use) + def get_nodes(self, *, sort: bool = False) -> list[int]: + """Get direct node indices in this region only.""" + if sort: + return sorted(self.nodes) + return list(self.nodes) - Returns: - Set of all node indices in this region and its descendants - """ + def get_region_nodes_and_descendants(self, _visited: set[int] | None = None) -> set[int]: + """Get all node indices recursively, including descendants.""" if _visited is None: _visited = set() # Detect cycles - if self.id in _visited: - logger.warning(f"Cycle detected in region {self.id} during node traversal") - return set() + assert self.id not in _visited, f"Cycle detected in region {self.id} during node traversal" _visited.add(self.id) all_nodes = set(self.nodes) for child in self.children: - all_nodes.update(child.get_all_nodes_recursive(_visited)) + all_nodes.update(child.get_region_nodes_and_descendants(_visited)) return all_nodes def contains_node(self, node_index: int) -> bool: """Check if region contains a specific node (direct only).""" return node_index in self.nodes - def contains_node_recursive(self, node_index: int, _visited: set[int] | None = None) -> bool: + def contains_node_within_region_and_descendants(self, node_index: int) -> bool: """Check if region contains a node recursively.""" - if _visited is None: - _visited = set() - - # Detect cycles - if self.id in _visited: - return False - - _visited.add(self.id) - - if self.contains_node(node_index): - return True - return any(child.contains_node_recursive(node_index, _visited) for child in self.children) - - # ========================================================================= - # Input/Output Management - # ========================================================================= + return node_index in self.get_region_nodes_and_descendants() def add_input(self, tensor_name: str) -> None: """Add an input tensor name.""" @@ -318,80 +184,31 @@ def add_output(self, 
tensor_name: str) -> None: if tensor_name not in self.outputs: self.outputs.append(tensor_name) - def get_inputs(self) -> list[str]: - """Get region input tensors.""" - return self.inputs - - def get_outputs(self) -> list[str]: - """Get region output tensors.""" - return self.outputs - - # ========================================================================= - # Size and Query Methods - # ========================================================================= - - def get_size(self) -> int: - """Get the number of direct nodes in this region. - - Returns: - Count of nodes directly in this region (excludes child regions) - """ - return len(self.nodes) - - def get_total_size(self, _visited: set[int] | None = None) -> int: - """Get total node count recursively including all descendants. - - Computes the sum of nodes in this region and all child regions, - providing the total footprint of the region subtree. - - Args: - _visited: Internal parameter for cycle detection (do not use) - - Returns: - Total number of nodes in this region and all descendants - """ + def get_size_of_region_and_descendants(self, _visited: set[int] | None = None) -> int: + """Get total node count recursively including all descendants.""" if _visited is None: _visited = set() # Detect cycles - if self.id in _visited: - logger.warning(f"Cycle detected in region {self.id} during size calculation") - return len(self.nodes) + assert self.id not in _visited, ( + f"Cycle detected in region {self.id} during size calculation" + ) _visited.add(self.id) total = len(self.nodes) for child in self.children: - total += child.get_total_size(_visited) + total += child.get_size_of_region_and_descendants(_visited) return total - # ========================================================================= - # Region Operations - # ========================================================================= - def merge(self, other: "Region") -> None: - """Merge another region into this one. 
- - Combines the nodes and children from the other region into this region. - The other region's children become children of this region, updating - their parent references accordingly. - - Args: - other: Region to merge into this one - """ + """Merge another region into this one.""" if not other: return - # Merge direct nodes self.nodes.update(other.nodes) - # Merge children (updates their parent references) for child in other.children: self.add_child(child) - # ========================================================================= - # String Representation - # ========================================================================= - - def to_string(self) -> str: - """Print region information for debugging.""" + def __repr__(self) -> str: type_str = self.type.value return ( f"Region[id={self.id}, level={self.level}, type={type_str}, " @@ -399,12 +216,6 @@ def to_string(self) -> str: f"inputs={len(self.inputs)}, outputs={len(self.outputs)}]" ) - def __str__(self) -> str: - return self.to_string() - - def __repr__(self) -> str: - return self.to_string() - def compute_structural_signature(self, graph: gs.Graph) -> str: """Compute deterministic structural signature for pattern matching. @@ -426,42 +237,9 @@ def compute_structural_signature(self, graph: gs.Graph) -> str: raise NotImplementedError("Not implemented") -# ============================================================================= -# Autotuner Q/DQ Insertion Specifications -# ============================================================================= - - @dataclass class InsertionScheme: - """Complete Q/DQ insertion specification for a region pattern. - - An InsertionScheme defines a complete Q/DQ configuration for a pattern, - combining both node-level and region-level insertion points. The scheme - is applied to all regions matching the pattern. 
- - **Scheme Identity:** - - Uniquely identified by the combination of insertion points (computed hash) - - latency_ms is a measured performance metric, not part of identity - - Two schemes with same insertion points but different latencies are considered identical - - **Application:** - - Node insertion points: Q/DQ at node inputs within the pattern - - Region insertion points: Q/DQ at child region boundaries (COMPOSITE only) - - All are resolved to actual configurations for each matching region - - **Performance Tracking:** - - latency_ms: Measured performance (inf = not yet measured) - - error: Whether this scheme encountered an error during measurement - - Used to select the best scheme for each pattern - - **Attributes:** - node_inputs: Q/DQ insertions at node inputs (list of NodeInputInsertionPoint) - child_region_inputs: Q/DQ insertions at child boundaries (list of ChildRegionInputInsertionPoint) - region_outputs: Q/DQ insertions at region outputs (list of RegionOutputInsertionPoint) - latency_ms: Measured latency in milliseconds (inf if not measured) - error: True if scheme measurement failed, False otherwise - profile_timestamp: ISO format timestamp when this scheme was profiled (None if not yet profiled) - """ + """Q/DQ insertion specification applied to all regions matching a pattern.""" node_inputs: list[NodeInputInsertionPoint] = field(default_factory=list) child_region_inputs: list[ChildRegionInputInsertionPoint] = field(default_factory=list) @@ -472,27 +250,7 @@ class InsertionScheme: @property def hash(self) -> str: - """Compute deterministic hash for scheme identity. - - The hash uniquely identifies this scheme configuration based on its - insertion points. Two schemes with identical insertion points produce - the same hash, regardless of their measured latencies. 
- - **Hash Input:** - - Sorted node_inputs (for deterministic ordering) - - Sorted child_region_inputs (for deterministic ordering) - - Sorted region_outputs (for deterministic ordering) - - latency_ms is EXCLUDED (performance metric, not identity) - - **Use Cases:** - - Detect duplicate schemes before measurement - - Group schemes by configuration - - Efficient scheme comparison - - Returns: - 32-character hexadecimal string (SHA-256 truncated to 128 bits) - """ - # Sort points for deterministic hashing + """Compute deterministic hash for scheme identity.""" sorted_nodes = sorted([(pt.node_index, pt.input_index) for pt in self.node_inputs]) sorted_regions = sorted( [(pt.region_index, pt.input_index) for pt in self.child_region_inputs] @@ -501,79 +259,20 @@ def hash(self) -> str: [(pt.region_index, pt.node_index, pt.output_index) for pt in self.region_outputs] ) - # Create hash input string hash_input = f"{sorted_nodes}|{sorted_regions}|{sorted_region_outputs}" - # Compute SHA-256 hash (128 bits) return hashlib.sha256(hash_input.encode("utf-8")).hexdigest()[:32] @property def is_empty(self) -> bool: - """Check if this is a baseline scheme with no Q/DQ insertions. - - Returns: - True if scheme has no node/region insertion points - """ - return ( - len(self.node_inputs) == 0 - and len(self.child_region_inputs) == 0 - and len(self.region_outputs) == 0 - ) - - @property - def has_error(self) -> bool: - """Check if this scheme encountered an error during measurement. - - Returns: - True if scheme has error=True, False otherwise - """ - return self.error + """Check if this is a baseline scheme with no Q/DQ insertions.""" + return not self.node_inputs and not self.child_region_inputs and not self.region_outputs @property def is_profiled(self) -> bool: - """Check if this scheme has been profiled (measured). - - A scheme is considered profiled if it has been measured (has non-infinite latency) - or has encountered an error during measurement. 
- - Returns: - True if scheme has been measured (latency_ms != inf) or has error, - False if scheme is waiting to be profiled (error=False and latency_ms=inf) - """ + """Check if this scheme has been profiled (measured).""" return self.error or self.latency_ms != float("inf") - @property - def num_node_insertions(self) -> int: - """Get count of node-level Q/DQ insertion points. - - Returns: - Number of NodeInputInsertionPoint entries - """ - return len(self.node_inputs) - - @property - def num_region_insertions(self) -> int: - """Get count of region-level Q/DQ insertion points. - - These specify Q/DQ insertions at child region boundaries within - COMPOSITE regions. - - Returns: - Number of ChildRegionInputInsertionPoint entries - """ - return len(self.child_region_inputs) - - @property - def num_region_output_insertions(self) -> int: - """Get count of region output insertion points. - - These specify Q/DQ insertions at outputs from child regions or nodes. - - Returns: - Number of RegionOutputInsertionPoint entries - """ - return len(self.region_outputs) - def to_dict(self) -> dict[str, Any]: """Convert to dictionary for serialization.""" return { @@ -588,19 +287,7 @@ def to_dict(self) -> dict[str, Any]: @classmethod def from_dict(cls, data: dict[str, Any]) -> "InsertionScheme": - """Create InsertionScheme from serialized dictionary. - - Reconstructs the insertion scheme from saved data, including node and - region insertion points. The hash is automatically recomputed from all - components to ensure consistency. 
- - Args: - data: Dictionary containing 'latency_ms', 'nodes_insertion_points', - 'child_region_inputs', and 'region_outputs' keys - - Returns: - Reconstructed InsertionScheme instance - """ + """Create InsertionScheme from serialized dictionary.""" scheme = cls() scheme.latency_ms = data.get("latency_ms", float("inf")) scheme.error = data.get("error", False) @@ -617,72 +304,22 @@ def from_dict(cls, data: dict[str, Any]) -> "InsertionScheme": RegionOutputInsertionPoint.from_dict(pt) for pt in data.get("region_outputs", []) ] - # Note: hash is computed from points, so we don't load it from dict - # This ensures consistency even if stored hash differs - return scheme def distance(self, other: "InsertionScheme") -> int: - """Compute edit distance between this scheme and another scheme. - - The edit distance is the minimum number of add/remove operations needed - to transform this scheme into the other scheme. This is computed as the - symmetric difference between the insertion point sets. - - **Distance Calculation:** - - Counts insertion points in self but not in other (need to be removed) - - Counts insertion points in other but not in self (need to be added) - - Considers all three types of insertion points: - * node_inputs - * child_region_inputs - * region_outputs - - Args: - other: InsertionScheme to compare against - - Returns: - Total edit distance (number of add + remove operations) - - Example: - >>> scheme1 = InsertionScheme( - ... node_inputs=[ - ... NodeInputInsertionPoint(0, 0), - ... NodeInputInsertionPoint(1, 0), - ... ] - ... ) - >>> scheme2 = InsertionScheme( - ... node_inputs=[ - ... NodeInputInsertionPoint(0, 0), - ... NodeInputInsertionPoint(2, 0), - ... ] - ... 
) - >>> scheme1.distance(scheme2) # 2 (remove (1,0), add (2,0)) - 2 - """ - # Convert insertion points to sets for efficient set operations - self_nodes = set(self.node_inputs) - other_nodes = set(other.node_inputs) - - self_regions = set(self.child_region_inputs) - other_regions = set(other.child_region_inputs) - - self_region_outputs = set(self.region_outputs) - other_region_outputs = set(other.region_outputs) - - # Compute symmetric difference (elements in either set but not both) - # This gives us the total number of add + remove operations - node_distance = len(self_nodes.symmetric_difference(other_nodes)) - region_distance = len(self_regions.symmetric_difference(other_regions)) - region_output_distance = len(self_region_outputs.symmetric_difference(other_region_outputs)) - - return node_distance + region_distance + region_output_distance + """Compute edit distance between this scheme and another scheme.""" + return ( + len(set(self.node_inputs).symmetric_difference(other.node_inputs)) + + len(set(self.child_region_inputs).symmetric_difference(other.child_region_inputs)) + + len(set(self.region_outputs).symmetric_difference(other.region_outputs)) + ) def __str__(self) -> str: """String representation for debugging.""" error_str = ", error=True" if self.error else "" return ( - f"InsertionScheme(node_insertions={self.num_node_insertions}, " - f"region_insertions={self.num_region_insertions}, " - f"region_output_insertions={self.num_region_output_insertions}, " + f"InsertionScheme(node_insertions={len(self.node_inputs)}, " + f"region_insertions={len(self.child_region_inputs)}, " + f"region_output_insertions={len(self.region_outputs)}, " f"latency={self.latency_ms:.3f}ms{error_str})" ) diff --git a/modelopt/onnx/quantization/autotune/insertion_points.py b/modelopt/onnx/quantization/autotune/insertion_points.py index 32722e44f..0be87b781 100644 --- a/modelopt/onnx/quantization/autotune/insertion_points.py +++ 
b/modelopt/onnx/quantization/autotune/insertion_points.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -13,46 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Q/DQ Insertion Point Management for ONNX Quantization. - -This module provides data structures and utilities for managing Quantization/Dequantization (Q/DQ) -insertion points in ONNX computational graphs during autotune optimization. It enables pattern-based -Q/DQ insertion that can be reused across multiple matching regions in a model. - -Core Concepts: --------------- -1. **Pattern-Relative Insertion Points**: Insertion points are defined relative to region patterns - rather than absolute node IDs, enabling scheme reuse across all matching regions. - -2. **Resolution Process**: Pattern-relative indices are resolved to actual tensor names for each - specific region instance, then Q/DQ pairs are inserted at the resolved locations. - -3. 
**Hierarchical Support**: Supports Q/DQ insertion at multiple levels: - - Node inputs within regions - - Child region boundaries (inputs/outputs) - - Region outputs - -Classes: --------- -- ResolvedInsertionPoint: Resolved Q/DQ insertion point with actual tensor name -- NodeInputInsertionPoint: Pattern-relative insertion point at node inputs -- ChildRegionInputInsertionPoint: Pattern-relative insertion point at child region inputs -- RegionOutputInsertionPoint: Pattern-relative insertion point at region/node outputs - -Utilities: ----------- -- skip_invalid_insertion_points(): Filter out non-quantizable tensors -- has_quantizable_operations(): Check if region contains major quantizable ops -- resolve_region_io_insertion_points(): Resolve region I/O to actual insertion points -- merge_resolved_insertion_points(): Merge insertion points when all users are quantized - -Constants: ----------- -- BOOL_OPERATIONS: Boolean/comparison operations (not quantizable) -- SHAPE_OPERATIONS: Shape manipulation operations (not quantizable) -- MAJOR_QUANTIZABLE_OPERATIONS: Key operations that benefit from quantization -""" +"""Q/DQ insertion point management for ONNX quantization autotune.""" +from abc import ABC, abstractmethod from dataclasses import dataclass from typing import TYPE_CHECKING, Any @@ -62,115 +25,53 @@ if TYPE_CHECKING: from modelopt.onnx.quantization.autotune.common import Region +from modelopt.onnx.op_types import ( + get_aggregation_ops, + get_bitwise_ops, + get_bool_ops, + get_comparison_ops, + get_conditional_ops, + get_copy_ops, + get_set_ops, + get_value_check_ops, + is_fusible_reduction_op, +) from modelopt.onnx.quantization.graph_utils import get_tensor_consumer_node_indices -BOOL_OPERATIONS = { - "Not", - "And", - "Or", - "Xor", - "BitwiseAnd", - "BitwiseOr", - "BitwiseXor", - "BitShift", - "IsNaN", - "IsInf", - "Sign", - "Abs", - "Equal", - "Greater", - "GreaterOrEqual", - "Less", - "LessOrEqual", - "Where", - "Max", - "Min", - "Mean", - "Median", - 
"ArgMax", - "ArgMin", - "ReduceMax", - "ReduceMin", - "ReduceSum", - "ReduceMean", - "All", - "Any", - "Unique", - "NonZero", - "TopK", -} - -SHAPE_OPERATIONS = { - "Cast", - "Ceil", - "Clip", - "Compress", - "Concat", - "ExpandDims", - "Flatten", - "Gather", - "GatherElements", - "GatherND", - "Identity", - "Pad", - "Range", - "Scatter", - "ScatterND", - "Shape", - "Slice", - "Split", - "Squeeze", - "Tile", - "Transpose", - "Unsqueeze", - "View", -} - -MAJOR_QUANTIZABLE_OPERATIONS = { - "Conv", - "ConvTranspose", - "Gemm", - "MatMul", - "AveragePool", - "MaxPool", - "GlobalAveragePool", - "GlobalMaxPool", - "Resize", - "Add", - "Sum", - "Mul", - "Relu", -} +class InsertionPoint(ABC): + """Abstract base class for pattern-relative Q/DQ insertion points.""" -@dataclass(frozen=True) -class ResolvedInsertionPoint: - """Resolved Q/DQ insertion point with actual tensor name and optional node context. + @abstractmethod + def to_dict(self) -> dict[str, Any]: + """Convert to dictionary for serialization.""" + ... + + @classmethod + @abstractmethod + def from_dict(cls, data: dict[str, Any]) -> "InsertionPoint": + """Create from dictionary.""" + ... - After resolving pattern-relative insertion points, this class represents the - actual location where Q/DQ pairs should be inserted in the graph. + @abstractmethod + def resolve(self, region: "Region", graph: gs.Graph) -> set["ResolvedInsertionPoint"]: + """Resolve pattern-relative insertion point to actual tensor names.""" + ... - **Insertion Modes:** - 1. Node-specific insertion (node_index and input_index are set): - - Inserts Q/DQ at a specific input of a specific node - - More precise control over where quantization happens - 2. 
Tensor-level insertion (node_index and input_index are None): - - Inserts Q/DQ for all users of the tensor - - Used when all consumers of a tensor should be quantized together + @staticmethod + @abstractmethod + def collect_from_region(region: "Region", graph: gs.Graph) -> list["InsertionPoint"]: + """Collect all valid insertion points of this type from a region.""" + ... - **Attributes:** - - tensor_name: Name of the tensor where Q/DQ should be inserted - - node_index: Absolute graph node index (not pattern-relative), or None for tensor-level insertion - - input_index: Input tensor index of that node, or None for tensor-level insertion - This class is immutable (frozen) to allow safe use in sets and as dict keys. - """ +@dataclass(frozen=True) +class ResolvedInsertionPoint: + """Resolved Q/DQ insertion point with actual tensor name and optional node context.""" tensor_name: str - # Absolute graph node index (or None for tensor-level insertion) - node_index: int | None = None - # Input tensor index of that node (or None) - input_index: int | None = None + node_index: int | None = None # Absolute graph node index (or None for tensor-level insertion) + input_index: int | None = None # Input tensor index of that node (or None) def to_dict(self) -> dict[str, Any]: """Convert to dictionary for serialization.""" @@ -189,43 +90,13 @@ def from_dict(cls, data: dict[str, Any]) -> "ResolvedInsertionPoint": input_index=data.get("input_index"), ) - def __str__(self) -> str: - """String representation for debugging.""" - return ( - f"ResolvedInsertionPoint(tensor_name={self.tensor_name}, " - f"node={self.node_index}, input={self.input_index})" - ) - @dataclass(frozen=True) -class NodeInputInsertionPoint: - """Pattern-relative Q/DQ insertion point at a node's input. - - Specifies where to insert a Q/DQ pair within a region pattern using - pattern-relative indices rather than absolute node IDs. This enables - insertion scheme reuse across all regions matching the same pattern. 
- - **Resolution Process:** - 1. Pattern-relative indices (node_index, input_index) are defined once - 2. For each matching region, indices are resolved to actual tensor names - 3. Q/DQ pairs are inserted at the resolved tensor locations +class NodeInputInsertionPoint(InsertionPoint): + """Pattern-relative Q/DQ insertion point at a node's input (frozen/hashable).""" - **Example:** - - NodeInputInsertionPoint(node_index=0, input_index=1) - - Resolves to: the second input (index 1) of the first node (index 0) in the pattern - - Actual tensor name depends on the specific region instance - - **Attributes:** - - node_index: Index of the node within the pattern's sorted node list (0-based) - - input_index: Index of the input tensor for that node (0-based) - - This class is immutable (frozen) to allow safe use in sets and as dict keys. - """ - - # Pattern-relative node index - node_index: int - # Input tensor index of that node - input_index: int + node_index: int # Pattern-relative node index + input_index: int # Input tensor index of that node def to_dict(self) -> dict[str, Any]: """Convert to dictionary for serialization.""" @@ -236,138 +107,62 @@ def from_dict(cls, data: dict[str, Any]) -> "NodeInputInsertionPoint": """Create from dictionary.""" return cls(node_index=data["node_index"], input_index=data["input_index"]) - def __str__(self) -> str: - """String representation for debugging.""" - return f"NodeInputInsertionPoint(node={self.node_index}, input={self.input_index})" - def resolve(self, region: "Region", graph: gs.Graph) -> set[ResolvedInsertionPoint]: - """Resolve a node input insertion point to actual tensor names for a matching region. - - Converts pattern-relative node/input indices to absolute node indices and actual - tensor names in the graph. Special handling for Conv/ConvTranspose operations - automatically includes weight quantization when input is quantized. 
- - Args: - region: The region instance matching this pattern - graph: The ONNX graph containing the nodes - - Returns: - Set of ResolvedInsertionPoint objects with actual tensor names - """ - nodes_list = list(graph.nodes) - node_indices = sorted(region.get_nodes()) - resolved_ips = set() - - # Map from pattern-relative node index to absolute graph node index + """Resolve a node input insertion point to actual tensor names for a matching region.""" + node_indices = region.get_nodes(sort=True) assert self.node_index < len(node_indices), "Node index out of range" actual_node_idx = node_indices[self.node_index] - assert actual_node_idx < len(nodes_list), "Node index out of range" - node = nodes_list[actual_node_idx] + node = graph.nodes[actual_node_idx] assert self.input_index < len(node.inputs), "Input index out of range" - # Resolve the input tensor name using input_index - inp = node.inputs[self.input_index] - if hasattr(inp, "name") and inp.name: - ip = ResolvedInsertionPoint( - tensor_name=inp.name, node_index=actual_node_idx, input_index=self.input_index - ) - resolved_ips.add(ip) - + resolved_ips = set() + # Determine which input indices to resolve (include weights for Conv/ConvTranspose) + input_indices = [self.input_index] if node.op in ["Conv", "ConvTranspose"]: assert self.input_index == 0, ( - "Conv and ConvTranspose inputs and weights should be quantized at same time" + "Conv/ConvTranspose inputs and weights must be quantized together" ) - assert len(node.inputs) >= 2, "Conv and ConvTranspose should have at least 2 inputs" - inp = node.inputs[1] + assert len(node.inputs) >= 2, "Conv/ConvTranspose should have at least 2 inputs" + input_indices.append(1) + + for idx in input_indices: + inp = node.inputs[idx] if hasattr(inp, "name") and inp.name: - ip = ResolvedInsertionPoint( - tensor_name=inp.name, node_index=actual_node_idx, input_index=1 + resolved_ips.add( + ResolvedInsertionPoint( + tensor_name=inp.name, node_index=actual_node_idx, input_index=idx + ) 
) - resolved_ips.add(ip) - return resolved_ips @staticmethod def collect_from_region(region: "Region", graph: gs.Graph) -> list["NodeInputInsertionPoint"]: - """Collect all valid node input insertion points from a region. - - Analyzes each node in the region and identifies all valid input tensors - where Q/DQ pairs could be inserted. Filters out invalid insertion points - using skip_invalid_insertion_points(). - - Args: - region: The region to collect insertion points from - graph: The ONNX graph containing the nodes - - Returns: - List of NodeInputInsertionPoint objects representing valid insertion locations - """ - nodes_list = list(graph.nodes) - node_indices = sorted(region.get_nodes()) - - node_input_insertion_points = [] + """Collect all valid node input insertion points from a region.""" + node_indices = region.get_nodes(sort=True) + insertion_points = [] for local_idx, node_idx in enumerate(node_indices): - assert node_idx < len(nodes_list), "Node index out of range" - node = nodes_list[node_idx] - # Analyze each input of the node + node = graph.nodes[node_idx] for input_idx, inp in enumerate(node.inputs): - # Skip if tensor doesn't have a valid name if not (hasattr(inp, "name") and inp.name): continue - # Skip if insertion point is invalid (wrong dtype, small size, special input, etc.) if skip_invalid_insertion_points(graph, inp.name, node): continue - # Create insertion point for valid tensor - ip = NodeInputInsertionPoint( - # Pattern-relative node index - node_index=local_idx, - input_index=input_idx, + insertion_points.append( + NodeInputInsertionPoint(node_index=local_idx, input_index=input_idx) ) - node_input_insertion_points.append(ip) - - return node_input_insertion_points + return insertion_points @dataclass(frozen=True) -class ChildRegionInputInsertionPoint: - """Pattern-relative Q/DQ insertion point at a child region's input boundary. - - Specifies where to insert Q/DQ pairs at the input boundaries of child regions - within COMPOSITE regions. 
This allows parent regions to control quantization - at child boundaries, potentially overriding or complementing child region - optimizations. - - **Use Case:** - Parent regions can insert Q/DQ pairs at child region inputs to: - - Add quantization at child boundaries even if the child has no internal Q/DQ - - Override or supplement the child's own boundary Q/DQ decisions - - Apply different quantization schemes based on the parent context - - **Resolution Process:** - 1. Pattern-relative indices (region_index, input_index) are defined once - 2. For each matching parent region, indices resolve to actual child boundaries: - - region_index identifies which child region (in parent's sorted child list) - - input_index identifies which input tensor of that child region - 3. Q/DQ pairs are inserted at the resolved child input tensor locations - - **Example:** - - ChildRegionInputInsertionPoint(region_index=0, input_index=1) - - Resolves to: the second input tensor (index 1) of the first child region (index 0) - - Actual tensor name depends on the specific parent/child region instances - - **Note:** Only applies to COMPOSITE regions. LEAF regions have no children, - so child region insertion points have no effect there. - - **Attributes:** - - region_index: Index of the child region within the parent pattern's sorted child list (0-based) - - input_index: Index of the input tensor for that child region (0-based) - - This class is immutable (frozen) to allow safe use in sets and as dict keys. +class ChildRegionInputInsertionPoint(InsertionPoint): + """Pattern-relative Q/DQ insertion point at a child region's input boundary (frozen/hashable). + + Only applies to COMPOSITE regions; LEAF regions have no children. 
""" - # Index of the child region within the parent pattern's sorted child list (0-based) + # Pattern-relative child region index region_index: int - # Index of the input tensor for that child region (0-based) + # Input tensor index of that child region input_index: int def to_dict(self) -> dict[str, Any]: @@ -376,160 +171,51 @@ def to_dict(self) -> dict[str, Any]: @classmethod def from_dict(cls, data: dict[str, Any]) -> "ChildRegionInputInsertionPoint": - """Create from dictionary. - - Backward compatible: Ignores obsolete fields like 'child_region_id' - from older serialization formats. - - Args: - data: Dictionary with 'region_index' and 'input_index' keys - - Returns: - ChildRegionInputInsertionPoint instance - """ - # Ignore child_region_id if present in old data + """Create from dictionary.""" return cls(region_index=data["region_index"], input_index=data["input_index"]) - def __str__(self) -> str: - """String representation for debugging.""" - return ( - f"ChildRegionInputInsertionPoint(region={self.region_index}, input={self.input_index})" - ) - def resolve(self, region: "Region", graph: gs.Graph) -> set[ResolvedInsertionPoint]: - """Resolve a child region input insertion point to actual tensor names for a matching region. - - Converts pattern-relative child region index and input index to the actual tensor - name at that child region's input boundary, then resolves to all node inputs that - consume that tensor. - - Args: - region: The parent region instance matching this pattern - graph: The ONNX graph containing the nodes - - Returns: - Set of ResolvedInsertionPoint objects with actual tensor names. - Returns empty set for LEAF regions (no children). 
- """ + """Resolve a child region input insertion point to actual tensor names.""" from modelopt.onnx.quantization.autotune.common import RegionType - if graph is None: - raise ValueError("graph parameter is required") - - # LEAF regions have no child boundaries - if region.get_type() == RegionType.LEAF: + if region.type == RegionType.LEAF: return set() - # Get sorted child regions (must match order in RegionPattern._compute_signature_recursive) - children_regions = region.get_children() - children_regions = sorted( - children_regions, key=lambda r: (-r.get_level(), r.get_total_size()) - ) - # Map from pattern-relative child index to actual child region - resolved_ips = set() + children_regions = region.get_children(sort=True) assert self.region_index < len(children_regions), "Child region index out of range" child_region = children_regions[self.region_index] - assert self.input_index < len(child_region.get_inputs()), "Input index out of range" - # Resolve the input tensor name using input_index - tensor_name = child_region.get_inputs()[self.input_index] - assert tensor_name is not None, "Tensor name is required" - resolved_ips.update(resolve_region_io_insertion_points(child_region, graph, tensor_name)) - - return resolved_ips + assert self.input_index < len(child_region.inputs), "Input index out of range" + tensor_name = child_region.inputs[self.input_index] + return resolve_region_io_insertion_points(child_region, graph, tensor_name) @staticmethod def collect_from_region( region: "Region", graph: gs.Graph ) -> list["ChildRegionInputInsertionPoint"]: - """Collect all valid child region input insertion points from a region. - - For COMPOSITE regions, analyzes each child region and identifies all valid - input tensors where Q/DQ pairs could be inserted at child boundaries. - Returns empty list for LEAF regions (no children). 
- - Args: - region: The parent region to collect insertion points from - graph: The ONNX graph containing the nodes - - Returns: - List of ChildRegionInputInsertionPoint objects representing valid insertion locations - """ + """Collect all valid child region input insertion points from a region.""" from modelopt.onnx.quantization.autotune.common import RegionType - child_region_input_insertion_points = [] - - # Only COMPOSITE regions have child boundaries for Q/DQ insertion - if region.get_type() != RegionType.LEAF: - # Get all child regions, sorted for deterministic ordering - # Must match sorting in _compute_signature_recursive to ensure - # insertion point indices align with pattern structure - children_regions = region.get_children() - children_regions = sorted( - children_regions, key=lambda r: (-r.get_level(), r.get_total_size()) - ) - - for local_idx, child_region in enumerate(children_regions): - # Create insertion point for each input tensor of the child region - for input_idx, inp in enumerate(child_region.get_inputs()): - if skip_invalid_insertion_points(graph, inp, child_region): - continue - point = ChildRegionInputInsertionPoint( - # Child region index within parent pattern - region_index=local_idx, - # Input index within child region - input_index=input_idx, - ) - child_region_input_insertion_points.append(point) + if region.type == RegionType.LEAF: + return [] - return child_region_input_insertion_points + insertion_points = [] + for local_idx, child_region in enumerate(region.get_children(sort=True)): + for input_idx, inp in enumerate(child_region.inputs): + if skip_invalid_insertion_points(graph, inp, child_region): + continue + insertion_points.append( + ChildRegionInputInsertionPoint(region_index=local_idx, input_index=input_idx) + ) + return insertion_points @dataclass(frozen=True) -class RegionOutputInsertionPoint: - """Pattern-relative Q/DQ insertion point at an output location. - - Specifies where to insert Q/DQ pairs at output boundaries. 
This can be either: - 1. Output from a child region (in COMPOSITE regions) - 2. Output from a node within the region - - **Use Case:** - Parent regions can: - - Add Q/DQ at child region output boundaries - - Add Q/DQ at node outputs within the region - - Control quantization precision as data flows through the region hierarchy - - **Resolution Process:** - 1. Pattern-relative indices are defined once - 2. If output is from a child region: use region_index (node_index is None) - - region_index identifies which child region (in sorted order) - - output_index identifies which output tensor of that child region - 3. If output is from a node: use node_index (region_index is None) - - node_index identifies which node (in sorted order) - - output_index identifies which output tensor of that node - 4. Resolves to the actual tensor name at that output location - - **Examples:** - - RegionOutputInsertionPoint(region_index=0, node_index=None, output_index=0) - → First output of the first child region - - RegionOutputInsertionPoint(region_index=None, node_index=2, output_index=1) - → Second output of the third node - - **Note:** Exactly one of region_index or node_index must be set (the other must be None). - - **Attributes:** - - region_index: Index of child region within parent pattern (0-based), or None - - node_index: Index of node within the region (0-based), or None - - output_index: Index of the output tensor (0-based) - - This class is immutable (frozen) to allow safe use in sets and as dict keys. 
- """ +class RegionOutputInsertionPoint(InsertionPoint): + """Pattern-relative Q/DQ insertion point at a child region or node output (frozen/hashable).""" - # Index of child region within parent pattern (0-based), or None - region_index: int | None - # Index of node within the region (0-based), or None - node_index: int | None - # Index of the output tensor (0-based) - output_index: int + region_index: int | None # Pattern-relative child region index (or None) + node_index: int | None # Pattern-relative node index (or None) + output_index: int # Output tensor index def to_dict(self) -> dict[str, Any]: """Convert to dictionary for serialization.""" @@ -541,183 +227,83 @@ def to_dict(self) -> dict[str, Any]: @classmethod def from_dict(cls, data: dict[str, Any]) -> "RegionOutputInsertionPoint": - """Create from dictionary. - - Args: - data: Dictionary with 'region_index', 'node_index', and 'output_index' keys - - Returns: - RegionOutputInsertionPoint instance - """ + """Create from dictionary.""" return cls( region_index=data.get("region_index"), node_index=data.get("node_index"), output_index=data["output_index"], ) - def __str__(self) -> str: - """String representation for debugging.""" - if self.region_index is not None: - return f"RegionOutputInsertionPoint(region={self.region_index}, output={self.output_index})" - else: - return f"RegionOutputInsertionPoint(node={self.node_index}, output={self.output_index})" - def resolve(self, region: "Region", graph: gs.Graph) -> set[ResolvedInsertionPoint]: - """Resolve a region output insertion point to actual tensor names for a matching region. - - Converts pattern-relative indices to the actual tensor name at an output location: - - If region_index is set: Resolves to a child region's output tensor - - If node_index is set: Resolves to a node's output tensor - - Then identifies all node inputs that consume that output tensor. 
- - Args: - region: The region instance matching this pattern - graph: The ONNX graph containing the nodes - - Returns: - Set of ResolvedInsertionPoint objects with actual tensor names - """ - if graph is None: - raise ValueError("graph parameter is required") - - # Get sorted nodes for node output resolution - nodes_list = list(graph.nodes) - node_indices = sorted(region.get_nodes()) - children_regions = region.get_children() - children_regions = sorted( - children_regions, key=lambda r: (-r.get_level(), r.get_total_size()) - ) - - # Resolve each region output insertion point from the scheme to actual tensor names - resolved_ips = set() - # Handle child region outputs (region_index is set) + """Resolve a region output insertion point to actual tensor names.""" if self.region_index is not None: + children_regions = region.get_children(sort=True) assert self.region_index < len(children_regions), "Region index out of range" child_region = children_regions[self.region_index] - assert self.output_index < len(child_region.get_outputs()), "Output index out of range" - tensor_name = child_region.get_outputs()[self.output_index] - assert tensor_name is not None, "Invalid tensor name" - resolved_ips.update( - resolve_region_io_insertion_points(child_region, graph, tensor_name) - ) - # Handle node outputs (node_index is set) - elif self.node_index is not None: + assert self.output_index < len(child_region.outputs), "Output index out of range" + tensor_name = child_region.outputs[self.output_index] + return resolve_region_io_insertion_points(child_region, graph, tensor_name) + + if self.node_index is not None: + node_indices = region.get_nodes(sort=True) assert self.node_index < len(node_indices), "Node index out of range" - node_idx = node_indices[self.node_index] - assert node_idx < len(nodes_list), "Node index out of range" - node = nodes_list[node_idx] + node = graph.nodes[node_indices[self.node_index]] assert self.output_index < len(node.outputs), "Output index out of 
range" tensor = node.outputs[self.output_index] - assert tensor is not None, "Invalid tensor name" assert hasattr(tensor, "name") and tensor.name, "Tensor name is required" - resolved_ips.update(resolve_region_io_insertion_points(None, graph, tensor.name)) - return resolved_ips + return resolve_region_io_insertion_points(None, graph, tensor.name) + + return set() @staticmethod def collect_from_region( region: "Region", graph: gs.Graph ) -> list["RegionOutputInsertionPoint"]: - """Collect all valid region output insertion points from a region. - - Identifies all valid output tensors (from child regions or nodes) that leave - the region boundary and could have Q/DQ pairs inserted. Only includes outputs - that are actual region outputs (not consumed internally). - - For COMPOSITE regions: - - Collects child region outputs that are also region outputs - - Collects node outputs that are region outputs - - For LEAF regions: - - Only collects node outputs that are region outputs - - Args: - region: The region to collect insertion points from - graph: The ONNX graph containing the nodes - - Returns: - List of RegionOutputInsertionPoint objects representing valid insertion locations - """ + """Collect all valid region output insertion points from a region.""" from modelopt.onnx.quantization.autotune.common import RegionType - nodes_list = list(graph.nodes) - node_indices = sorted(region.get_nodes()) - region_outputs_set = set(region.get_outputs()) - - # Only include outputs that are actual region outputs (leave the region) - region_output_insertion_points = [] - if region.get_type() != RegionType.LEAF: - # For COMPOSITE regions: check if child region output is a region output - children_regions = region.get_children() - children_regions = sorted( - children_regions, key=lambda r: (-r.get_level(), r.get_total_size()) - ) - for local_idx, child_region in enumerate(children_regions): - for output_idx, out in enumerate(child_region.get_outputs()): - if out not in 
region_outputs_set: - continue - if skip_invalid_insertion_points(graph, out, child_region): - continue - point = RegionOutputInsertionPoint( - region_index=local_idx, - node_index=None, - output_index=output_idx, - ) - region_output_insertion_points.append(point) - # For all regions: check if node output is a region output - for local_idx, node_idx in enumerate(node_indices): - assert node_idx < len(nodes_list), "Node index out of range" - node = nodes_list[node_idx] + region_outputs_set = set(region.outputs) + insertion_points = [] + + # For COMPOSITE regions: collect child region outputs + if region.type != RegionType.LEAF: + for local_idx, child_region in enumerate(region.get_children(sort=True)): + for output_idx, out in enumerate(child_region.outputs): + if out in region_outputs_set and not skip_invalid_insertion_points( + graph, out, child_region + ): + insertion_points.append( + RegionOutputInsertionPoint( + region_index=local_idx, node_index=None, output_index=output_idx + ) + ) + + # For all regions: collect node outputs + for local_idx, node_idx in enumerate(region.get_nodes(sort=True)): + node = graph.nodes[node_idx] for output_idx, out in enumerate(node.outputs): - # Skip if tensor doesn't have a valid name if not (hasattr(out, "name") and out.name): continue - # Skip if this output is not a region output (i.e., it's consumed internally) - if out.name not in region_outputs_set: - continue - # Skip if insertion point is invalid (wrong dtype, small size, etc.) 
- if skip_invalid_insertion_points(graph, out.name, node): - continue - # Create insertion point for valid output tensor - point = RegionOutputInsertionPoint( - region_index=None, - node_index=local_idx, - output_index=output_idx, - ) - region_output_insertion_points.append(point) - - return region_output_insertion_points - + if out.name in region_outputs_set and not skip_invalid_insertion_points( + graph, out.name, node + ): + insertion_points.append( + RegionOutputInsertionPoint( + region_index=None, node_index=local_idx, output_index=output_idx + ) + ) -InsertionPointType = ( - NodeInputInsertionPoint | ChildRegionInputInsertionPoint | RegionOutputInsertionPoint -) + return insertion_points def skip_invalid_insertion_points( graph: gs.Graph, tensor_name: str, region_or_node: "Region | gs.Node" ) -> bool: - """Determine if a tensor should be skipped for Q/DQ insertion. - - Filters out tensors that are not suitable for quantization based on various criteria: - - Boolean and shape operations (not quantizable) - - Fused operation patterns (Conv->BatchNorm->ReLU) - - Operation-specific non-quantizable inputs (weights, biases, BN parameters) - - Non-floating-point tensors (indices, masks) - - Small tensors (scalars, small vectors with < 8 elements) - - Args: - graph: The ONNX graph containing the nodes - tensor_name: Name of the tensor to evaluate - region_or_node: Either a Region or a Node to check for usage of this tensor - - Returns: - True if the insertion point should be skipped, False if it's valid for quantization - """ + """Determine if a tensor should be skipped for Q/DQ insertion.""" from modelopt.onnx.quantization.autotune.common import Region if isinstance(region_or_node, Region): - node_indices = region_or_node.get_all_nodes_recursive() + node_indices = region_or_node.get_region_nodes_and_descendants() nodes: list[gs.Node] = [graph.nodes[node_idx] for node_idx in node_indices] else: assert isinstance(region_or_node, gs.Node) @@ -729,24 +315,19 @@ def 
skip_invalid_insertion_points( # Skip weights of Conv and ConvTranspose, they should be quantized with inputs at same time if node.op in ["Conv", "ConvTranspose"] and input_idx >= 1: return True - if node.op in ["Relu", "LeakyRelu", "Softmax"]: - # Conv -> ReLU/LeakyRelu/Softmax + # Conv -> ReLU/Softmax or Conv -> BatchNormalization -> ReLU/Softmax + if node.op in ["Relu", "Softmax"]: if len(node.inputs) == 1 and len(node.inputs[0].inputs) == 1: producer = node.inputs[0].inputs[0] if producer.op in ["Conv", "ConvTranspose"]: return True - # Conv -> BatchNormalization -> ReLU/LeakyRelu/Softmax - if len(node.inputs) == 1 and len(node.inputs[0].inputs) == 1: - producer = node.inputs[0].inputs[0] - if producer.op == "BatchNormalization": - assert len(producer.inputs) >= 1, ( - "BN node should have more than one inputs" - ) - if len(producer.inputs[0].inputs) == 1: - producer = producer.inputs[0].inputs[0] - if producer.op in ["Conv", "ConvTranspose"]: - return True - # Conv -> BatchNormalization -> ReLU/LeakyRelu/Softmax + if ( + producer.op == "BatchNormalization" + and len(producer.inputs[0].inputs) == 1 + and producer.inputs[0].inputs[0].op in ["Conv", "ConvTranspose"] + ): + return True + # Conv -> BatchNormalization if node.op == "BatchNormalization": assert len(node.inputs) >= 1, "BN node should have more than one inputs" if len(node.inputs[0].inputs) == 1: @@ -754,10 +335,18 @@ def skip_invalid_insertion_points( if producer.op in ["Conv", "ConvTranspose"]: return True # Filter 1: out boolean operations - if node.op in BOOL_OPERATIONS: + if node.op in ( + get_bool_ops() + | get_bitwise_ops() + | get_value_check_ops() + | get_comparison_ops() + | get_conditional_ops() + | get_aggregation_ops() + | get_set_ops() + ) or is_fusible_reduction_op(node.op): return True # Filter 2: out shape operations - if node.op in SHAPE_OPERATIONS: + if node.op in get_autotuner_skip_ops(): return True # Filter 3: Skip operation-specific non-quantizable inputs if node.op in 
["BatchNormalization", "Resize"] and input_idx >= 1: @@ -781,76 +370,40 @@ def skip_invalid_insertion_points( def has_quantizable_operations(region: "Region", graph: gs.Graph) -> bool: - """Check if a region contains major quantizable operations. - - Args: - region: The region to check - graph: The ONNX graph containing the nodes - - Returns: - True if the region contains major quantizable operations, False otherwise - """ + """Check if a region contains major quantizable operations (only checks LEAF regions).""" from modelopt.onnx.quantization.autotune.common import RegionType - # only check leaf regions for quantizable operations - if region.get_type() == RegionType.LEAF: - region_ops = {graph.nodes[idx].op for idx in region.get_nodes()} - return bool(region_ops.intersection(MAJOR_QUANTIZABLE_OPERATIONS)) - return True + if region.type != RegionType.LEAF: + return True + region_ops = {graph.nodes[idx].op for idx in region.get_nodes()} + return bool(region_ops & get_autotuner_quantizable_ops()) def resolve_region_io_insertion_points( region: "Region | None", graph: gs.Graph, tensor_name: str ) -> set[ResolvedInsertionPoint]: - """Resolve region input/output boundaries to actual Q/DQ insertion points. - - For a given tensor at a region boundary (input or output), this function - identifies all the actual node inputs where Q/DQ pairs should be inserted. - It considers both nodes within the region (if provided) and all users of - the tensor in the graph. 
- - **Use Cases:** - - Child region inputs: Find all nodes inside the child that consume the input tensor - - Child region outputs: Find all nodes outside the child that consume the output tensor - - Node outputs: Find all nodes that consume the tensor (region can be None) - - Args: - region: The region to search within (or None to search entire graph) - graph: The ONNX graph containing the nodes - tensor_name: Name of the tensor at the region boundary - - Returns: - Set of ResolvedInsertionPoint objects specifying where to insert Q/DQ pairs - """ - resolved_insertion_points = set() - tensor_users_map: dict[str, list[int]] = {} - if hasattr(graph, "tensor_users_map"): - tensor_users_map = graph.tensor_users_map - if not tensor_users_map: - tensor_users_map = get_tensor_consumer_node_indices(graph) + """Resolve region input/output boundaries to actual Q/DQ insertion points.""" + tensor_users_map = getattr(graph, "tensor_users_map", None) or get_tensor_consumer_node_indices( + graph + ) + node_indices: set[int] = set() if region is not None: - for node_idx in region.get_all_nodes_recursive(): - assert node_idx < len(graph.nodes), "Node index out of range" - node = graph.nodes[node_idx] - for input_idx, inp in enumerate(node.inputs): - if inp.name == tensor_name: - ip = ResolvedInsertionPoint( - tensor_name=tensor_name, node_index=node_idx, input_index=input_idx - ) - resolved_insertion_points.add(ip) + node_indices.update(region.get_region_nodes_and_descendants()) + node_indices.update(tensor_users_map.get(tensor_name, [])) - if tensor_name in tensor_users_map: - for node_idx in tensor_users_map[tensor_name]: - node = graph.nodes[node_idx] - for input_idx, inp in enumerate(node.inputs): - if inp.name == tensor_name: - ip = ResolvedInsertionPoint( - tensor_name=tensor_name, node_index=node_idx, input_index=input_idx + resolved = set() + for node_idx in node_indices: + node = graph.nodes[node_idx] + for input_idx, inp in enumerate(node.inputs): + if hasattr(inp, 
"name") and inp.name == tensor_name: + if not skip_invalid_insertion_points(graph, tensor_name, node): + resolved.add( + ResolvedInsertionPoint( + tensor_name=tensor_name, node_index=node_idx, input_index=input_idx + ) ) - resolved_insertion_points.add(ip) - - return resolved_insertion_points + return resolved def merge_resolved_insertion_points( @@ -858,40 +411,59 @@ def merge_resolved_insertion_points( ) -> set[ResolvedInsertionPoint]: """Optimize insertion points by merging node-specific insertions into tensor-level insertions. - When all consumers (users) of a tensor have Q/DQ insertion points, it's more efficient - to insert Q/DQ once at the tensor level rather than at each individual node input. - This reduces the number of Q/DQ nodes in the graph and simplifies the quantization scheme. - - **Optimization Logic:** - - For each tensor with multiple node-specific insertion points: - - If ALL users of the tensor have insertion points → merge to tensor-level insertion - - If SOME users have insertion points → keep node-specific insertions - - Args: - graph: The ONNX graph containing the nodes - resolved_insertion_points: Set of resolved insertion points to optimize - - Returns: - Optimized set of insertion points with merged tensor-level insertions where possible + When all consumers of a tensor have Q/DQ insertion points, insert Q/DQ once at the + tensor level rather than at each individual node input. 
""" tensor_users_map = get_tensor_consumer_node_indices(graph) - node_input_insertion_points = { - ip for ip in resolved_insertion_points if ip.node_index is not None - } - tensor_names = {ip.tensor_name for ip in node_input_insertion_points} + node_ips = {ip for ip in resolved_insertion_points if ip.node_index is not None} - results = resolved_insertion_points.difference(node_input_insertion_points) - for tensor_name in tensor_names: - all_users = set(tensor_users_map[tensor_name]) - qdq_users = { - user for user in node_input_insertion_points if user.tensor_name == tensor_name - } - qdq_user_ids = set({user.node_index for user in qdq_users}) - if all_users == qdq_user_ids: + results = resolved_insertion_points - node_ips + for tensor_name in {ip.tensor_name for ip in node_ips}: + all_users = set(tensor_users_map.get(tensor_name, [])) + qdq_users = {ip for ip in node_ips if ip.tensor_name == tensor_name} + if all_users == {ip.node_index for ip in qdq_users}: results.add( ResolvedInsertionPoint(tensor_name=tensor_name, node_index=None, input_index=None) ) else: results.update(qdq_users) - return results + + +def get_autotuner_skip_ops(): + """Returns set of shape/structural operations that are not quantizable.""" + return set(get_copy_ops()) | { + # Additional indexing/scatter/reshape ops + "Compress", + "Scatter", + "ExpandDims", + "Unsqueeze", + "View", + "Pad", + # Utility ops + "Cast", + "Ceil", + "Clip", + "Identity", + "Range", + "Shape", + } + + +def get_autotuner_quantizable_ops(): + """Returns set of key operations that benefit from quantization.""" + return { + "Conv", + "ConvTranspose", + "Gemm", + "MatMul", + "AveragePool", + "MaxPool", + "GlobalAveragePool", + "GlobalMaxPool", + "Resize", + "Add", + "Sum", + "Mul", + "Relu", + } diff --git a/tests/unit/onnx/quantization/autotune/test_insertion_points.py b/tests/unit/onnx/quantization/autotune/test_insertion_points.py index d71524442..087bc8cfa 100644 --- 
a/tests/unit/onnx/quantization/autotune/test_insertion_points.py +++ b/tests/unit/onnx/quantization/autotune/test_insertion_points.py @@ -1,4 +1,3 @@ -#!/usr/bin/env python3 # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # diff --git a/tests/unit/onnx/quantization/autotune/test_region.py b/tests/unit/onnx/quantization/autotune/test_region.py index 714f8a051..df0cfaed3 100644 --- a/tests/unit/onnx/quantization/autotune/test_region.py +++ b/tests/unit/onnx/quantization/autotune/test_region.py @@ -1,4 +1,3 @@ -#!/usr/bin/env python3 # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # From 79496ea0a228da93fcbeb9222b0746e5e227b120 Mon Sep 17 00:00:00 2001 From: Will Guo Date: Tue, 27 Jan 2026 10:56:27 +0000 Subject: [PATCH 3/5] recover docstring and many fix Signed-off-by: Will Guo --- modelopt/onnx/quantization/autotune/common.py | 88 +++---- .../quantization/autotune/insertion_points.py | 144 ++++++++--- .../autotune/test_insertion_points.py | 238 ++++++++---------- .../onnx/quantization/autotune/test_region.py | 24 +- 4 files changed, 264 insertions(+), 230 deletions(-) diff --git a/modelopt/onnx/quantization/autotune/common.py b/modelopt/onnx/quantization/autotune/common.py index cab73039a..a8929315a 100644 --- a/modelopt/onnx/quantization/autotune/common.py +++ b/modelopt/onnx/quantization/autotune/common.py @@ -20,13 +20,11 @@ from enum import Enum from typing import Any -import onnx_graphsurgeon as gs - from modelopt.onnx.logging_config import logger from modelopt.onnx.quantization.autotune.insertion_points import ( ChildRegionInputInsertionPoint, + ChildRegionOutputInsertionPoint, NodeInputInsertionPoint, - RegionOutputInsertionPoint, ) @@ -83,7 +81,14 @@ def __init__(self, region_id: int, level: int, region_type: RegionType): self.metadata: dict[str, str] = {} def get_children(self, *, 
sort: bool = False) -> list["Region"]: - """Get all child regions.""" + """Get all child regions. If sort is True, sort the children by level and size. + + Args: + sort: Whether to sort the children by level and size + + Returns: + List of child regions + """ if sort: return sorted( self.children, key=lambda r: (-r.level, r.get_size_of_region_and_descendants()) @@ -138,14 +143,6 @@ def is_descendant_of(self, potential_ancestor: "Region") -> bool: current = current.parent return False - def add_node(self, node_index: int) -> None: - """Add a node index to this region.""" - self.nodes.add(node_index) - - def add_nodes(self, node_indices: list[int]) -> None: - """Add multiple node indices to this region.""" - self.nodes.update(node_indices) - def get_nodes(self, *, sort: bool = False) -> list[int]: """Get direct node indices in this region only.""" if sort: @@ -174,16 +171,6 @@ def contains_node_within_region_and_descendants(self, node_index: int) -> bool: """Check if region contains a node recursively.""" return node_index in self.get_region_nodes_and_descendants() - def add_input(self, tensor_name: str) -> None: - """Add an input tensor name.""" - if tensor_name not in self.inputs: - self.inputs.append(tensor_name) - - def add_output(self, tensor_name: str) -> None: - """Add an output tensor name.""" - if tensor_name not in self.outputs: - self.outputs.append(tensor_name) - def get_size_of_region_and_descendants(self, _visited: set[int] | None = None) -> int: """Get total node count recursively including all descendants.""" if _visited is None: @@ -216,41 +203,31 @@ def __repr__(self) -> str: f"inputs={len(self.inputs)}, outputs={len(self.outputs)}]" ) - def compute_structural_signature(self, graph: gs.Graph) -> str: - """Compute deterministic structural signature for pattern matching. - - Creates a signature that uniquely identifies the region's topology, - node operations, and hierarchical structure. 
Regions with identical - signatures can share Q/DQ insertion schemes. - - The signature captures: - - Node operation types and key parameters - - Hierarchical structure (child regions) - - Deterministic ordering (sorted for consistency) - - Args: - graph: The ONNX graph containing the region's nodes - - Returns: - Signature string (e.g., "Conv->BatchNorm->Relu" or "COMPOSITE(...)") - """ - raise NotImplementedError("Not implemented") - @dataclass class InsertionScheme: - """Q/DQ insertion specification applied to all regions matching a pattern.""" + """Complete Q/DQ insertion specification for a region pattern. + + An InsertionScheme defines a complete Q/DQ configuration for a pattern, + combining both node-level and region-level insertion points. The scheme + is applied to all regions matching the pattern. + """ node_inputs: list[NodeInputInsertionPoint] = field(default_factory=list) child_region_inputs: list[ChildRegionInputInsertionPoint] = field(default_factory=list) - region_outputs: list[RegionOutputInsertionPoint] = field(default_factory=list) + region_outputs: list[ChildRegionOutputInsertionPoint] = field(default_factory=list) latency_ms: float = float("inf") error: bool = False profile_timestamp: str | None = None @property def hash(self) -> str: - """Compute deterministic hash for scheme identity.""" + """Compute deterministic hash for scheme identity. + + The hash uniquely identifies this scheme configuration based on its + insertion points. Two schemes with identical insertion points produce + the same hash, regardless of their measured latencies. + """ sorted_nodes = sorted([(pt.node_index, pt.input_index) for pt in self.node_inputs]) sorted_regions = sorted( [(pt.region_index, pt.input_index) for pt in self.child_region_inputs] @@ -270,7 +247,11 @@ def is_empty(self) -> bool: @property def is_profiled(self) -> bool: - """Check if this scheme has been profiled (measured).""" + """Check if this scheme has been profiled (measured). 
+ + A scheme is considered profiled if it has been measured (has non-infinite latency) + or has encountered an error during measurement. + """ return self.error or self.latency_ms != float("inf") def to_dict(self) -> dict[str, Any]: @@ -301,13 +282,24 @@ def from_dict(cls, data: dict[str, Any]) -> "InsertionScheme": for pt in data.get("child_region_inputs", []) ] scheme.region_outputs = [ - RegionOutputInsertionPoint.from_dict(pt) for pt in data.get("region_outputs", []) + ChildRegionOutputInsertionPoint.from_dict(pt) for pt in data.get("region_outputs", []) ] return scheme def distance(self, other: "InsertionScheme") -> int: - """Compute edit distance between this scheme and another scheme.""" + """Compute edit distance between this scheme and another scheme. + + The edit distance is the minimum number of add/remove operations needed + to transform this scheme into the other scheme. This is computed as the + symmetric difference between the insertion point sets. + + Args: + other: InsertionScheme to compare against + + Returns: + Total edit distance (number of add + remove operations) + """ return ( len(set(self.node_inputs).symmetric_difference(other.node_inputs)) + len(set(self.child_region_inputs).symmetric_difference(other.child_region_inputs)) diff --git a/modelopt/onnx/quantization/autotune/insertion_points.py b/modelopt/onnx/quantization/autotune/insertion_points.py index 0be87b781..dd01848dd 100644 --- a/modelopt/onnx/quantization/autotune/insertion_points.py +++ b/modelopt/onnx/quantization/autotune/insertion_points.py @@ -13,10 +13,15 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Q/DQ insertion point management for ONNX quantization autotune.""" +"""Q/DQ insertion point management for ONNX quantization autotune. + +This module provides data structures and utilities for managing Quantization/Dequantization (Q/DQ) +insertion points in ONNX computational graphs during autotune optimization. 
It enables pattern-based +Q/DQ insertion that can be reused across multiple matching regions in a model. +""" from abc import ABC, abstractmethod -from dataclasses import dataclass +from dataclasses import asdict, dataclass from typing import TYPE_CHECKING, Any import numpy as np @@ -67,7 +72,14 @@ def collect_from_region(region: "Region", graph: gs.Graph) -> list["InsertionPoi @dataclass(frozen=True) class ResolvedInsertionPoint: - """Resolved Q/DQ insertion point with actual tensor name and optional node context.""" + """Resolved Q/DQ insertion point with actual tensor name and optional node context. + + After resolving pattern-relative insertion points, this class represents the + actual location where Q/DQ pairs should be inserted in the graph. It contains the + tensor name and the node index (if applicable) and input index (if applicable). + + This class is immutable (frozen) to allow safe use in sets and as dict keys. + """ tensor_name: str node_index: int | None = None # Absolute graph node index (or None for tensor-level insertion) @@ -75,32 +87,31 @@ class ResolvedInsertionPoint: def to_dict(self) -> dict[str, Any]: """Convert to dictionary for serialization.""" - return { - "tensor_name": self.tensor_name, - "node_index": self.node_index, - "input_index": self.input_index, - } + return asdict(self) @classmethod def from_dict(cls, data: dict[str, Any]) -> "ResolvedInsertionPoint": """Create from dictionary.""" - return cls( - tensor_name=data["tensor_name"], - node_index=data["node_index"], - input_index=data.get("input_index"), - ) + return cls(**data) @dataclass(frozen=True) class NodeInputInsertionPoint(InsertionPoint): - """Pattern-relative Q/DQ insertion point at a node's input (frozen/hashable).""" + """Pattern-relative Q/DQ insertion point at a node's input (frozen/hashable). + + Specifies where to insert a Q/DQ pair within a region pattern using + pattern-relative indices rather than absolute node IDs. 
This enables + insertion scheme reuse across all regions matching the same pattern. + + This class is immutable (frozen) to allow safe use in sets and as dict keys. + """ node_index: int # Pattern-relative node index input_index: int # Input tensor index of that node def to_dict(self) -> dict[str, Any]: """Convert to dictionary for serialization.""" - return {"node_index": self.node_index, "input_index": self.input_index} + return asdict(self) @classmethod def from_dict(cls, data: dict[str, Any]) -> "NodeInputInsertionPoint": @@ -143,9 +154,8 @@ def collect_from_region(region: "Region", graph: gs.Graph) -> list["NodeInputIns for local_idx, node_idx in enumerate(node_indices): node = graph.nodes[node_idx] for input_idx, inp in enumerate(node.inputs): - if not (hasattr(inp, "name") and inp.name): - continue - if skip_invalid_insertion_points(graph, inp.name, node): + name = getattr(inp, "name", None) + if not name or skip_invalid_insertion_points(graph, name, node): continue insertion_points.append( NodeInputInsertionPoint(node_index=local_idx, input_index=input_idx) @@ -157,7 +167,14 @@ def collect_from_region(region: "Region", graph: gs.Graph) -> list["NodeInputIns class ChildRegionInputInsertionPoint(InsertionPoint): """Pattern-relative Q/DQ insertion point at a child region's input boundary (frozen/hashable). + Specifies where to insert Q/DQ pairs at the input boundaries of child regions + within COMPOSITE regions. This allows parent regions to control quantization + at child boundaries, potentially overriding or complementing child region + optimizations. + Only applies to COMPOSITE regions; LEAF regions have no children. + + This class is immutable (frozen) to allow safe use in sets and as dict keys. 
""" # Pattern-relative child region index @@ -167,12 +184,12 @@ class ChildRegionInputInsertionPoint(InsertionPoint): def to_dict(self) -> dict[str, Any]: """Convert to dictionary for serialization.""" - return {"region_index": self.region_index, "input_index": self.input_index} + return asdict(self) @classmethod def from_dict(cls, data: dict[str, Any]) -> "ChildRegionInputInsertionPoint": """Create from dictionary.""" - return cls(region_index=data["region_index"], input_index=data["input_index"]) + return cls(**data) def resolve(self, region: "Region", graph: gs.Graph) -> set[ResolvedInsertionPoint]: """Resolve a child region input insertion point to actual tensor names.""" @@ -210,8 +227,15 @@ def collect_from_region( @dataclass(frozen=True) -class RegionOutputInsertionPoint(InsertionPoint): - """Pattern-relative Q/DQ insertion point at a child region or node output (frozen/hashable).""" +class ChildRegionOutputInsertionPoint(InsertionPoint): + """Pattern-relative Q/DQ insertion point at a child region or node output (frozen/hashable). + + Specifies where to insert Q/DQ pairs at output boundaries. This can be either: + 1. Output from a child region (in COMPOSITE regions) + 2. Output from a node within the region + + This class is immutable (frozen) to allow safe use in sets and as dict keys. 
+ """ region_index: int | None # Pattern-relative child region index (or None) node_index: int | None # Pattern-relative node index (or None) @@ -219,20 +243,12 @@ class RegionOutputInsertionPoint(InsertionPoint): def to_dict(self) -> dict[str, Any]: """Convert to dictionary for serialization.""" - return { - "region_index": self.region_index, - "node_index": self.node_index, - "output_index": self.output_index, - } + return asdict(self) @classmethod - def from_dict(cls, data: dict[str, Any]) -> "RegionOutputInsertionPoint": + def from_dict(cls, data: dict[str, Any]) -> "ChildRegionOutputInsertionPoint": """Create from dictionary.""" - return cls( - region_index=data.get("region_index"), - node_index=data.get("node_index"), - output_index=data["output_index"], - ) + return cls(**data) def resolve(self, region: "Region", graph: gs.Graph) -> set[ResolvedInsertionPoint]: """Resolve a region output insertion point to actual tensor names.""" @@ -258,7 +274,7 @@ def resolve(self, region: "Region", graph: gs.Graph) -> set[ResolvedInsertionPoi @staticmethod def collect_from_region( region: "Region", graph: gs.Graph - ) -> list["RegionOutputInsertionPoint"]: + ) -> list["ChildRegionOutputInsertionPoint"]: """Collect all valid region output insertion points from a region.""" from modelopt.onnx.quantization.autotune.common import RegionType @@ -273,7 +289,7 @@ def collect_from_region( graph, out, child_region ): insertion_points.append( - RegionOutputInsertionPoint( + ChildRegionOutputInsertionPoint( region_index=local_idx, node_index=None, output_index=output_idx ) ) @@ -288,7 +304,7 @@ def collect_from_region( graph, out.name, node ): insertion_points.append( - RegionOutputInsertionPoint( + ChildRegionOutputInsertionPoint( region_index=None, node_index=local_idx, output_index=output_idx ) ) @@ -299,7 +315,23 @@ def collect_from_region( def skip_invalid_insertion_points( graph: gs.Graph, tensor_name: str, region_or_node: "Region | gs.Node" ) -> bool: - """Determine if a 
tensor should be skipped for Q/DQ insertion.""" + """Determine if a tensor should be skipped for Q/DQ insertion. + + Filters out tensors that are not suitable for quantization based on various criteria: + - Boolean and shape operations (not quantizable) + - Fused operation patterns (Conv->BatchNorm->ReLU) + - Operation-specific non-quantizable inputs (weights, biases, BN parameters) + - Non-floating-point tensors (indices, masks) + - Small tensors (scalars, small vectors with < 8 elements) + + Args: + graph: The ONNX graph containing the nodes + tensor_name: Name of the tensor to evaluate + region_or_node: Either a Region or a Node to check for usage of this tensor + + Returns: + True if the insertion point should be skipped, False if it's valid for quantization + """ from modelopt.onnx.quantization.autotune.common import Region if isinstance(region_or_node, Region): @@ -370,7 +402,15 @@ def skip_invalid_insertion_points( def has_quantizable_operations(region: "Region", graph: gs.Graph) -> bool: - """Check if a region contains major quantizable operations (only checks LEAF regions).""" + """Check if a region contains major quantizable operations (only checks LEAF regions). + + Args: + region: The region to check + graph: The ONNX graph containing the nodes + + Returns: + True if the region contains major quantizable operations, False otherwise + """ from modelopt.onnx.quantization.autotune.common import RegionType if region.type != RegionType.LEAF: @@ -382,7 +422,21 @@ def has_quantizable_operations(region: "Region", graph: gs.Graph) -> bool: def resolve_region_io_insertion_points( region: "Region | None", graph: gs.Graph, tensor_name: str ) -> set[ResolvedInsertionPoint]: - """Resolve region input/output boundaries to actual Q/DQ insertion points.""" + """Resolve region input/output boundaries to actual Q/DQ insertion points. 
+ + For a given tensor at a region boundary (input or output), this function + identifies all the actual node inputs where Q/DQ pairs should be inserted. + It considers both nodes within the region (if provided) and all users of + the tensor in the graph. + + Args: + region: The region to search within (or None to search entire graph) + graph: The ONNX graph containing the nodes + tensor_name: Name of the tensor at the region boundary + + Returns: + Set of ResolvedInsertionPoint objects specifying where to insert Q/DQ pairs + """ tensor_users_map = getattr(graph, "tensor_users_map", None) or get_tensor_consumer_node_indices( graph ) @@ -411,8 +465,16 @@ def merge_resolved_insertion_points( ) -> set[ResolvedInsertionPoint]: """Optimize insertion points by merging node-specific insertions into tensor-level insertions. - When all consumers of a tensor have Q/DQ insertion points, insert Q/DQ once at the - tensor level rather than at each individual node input. + When all consumers (users) of a tensor have Q/DQ insertion points, it's more efficient + to insert Q/DQ once at the tensor level rather than at each individual node input. + This reduces the number of Q/DQ nodes in the graph and simplifies the quantization scheme. 
+ + Args: + graph: The ONNX graph containing the nodes + resolved_insertion_points: Set of resolved insertion points to optimize + + Returns: + Optimized set of insertion points with merged tensor-level insertions where possible """ tensor_users_map = get_tensor_consumer_node_indices(graph) node_ips = {ip for ip in resolved_insertion_points if ip.node_index is not None} diff --git a/tests/unit/onnx/quantization/autotune/test_insertion_points.py b/tests/unit/onnx/quantization/autotune/test_insertion_points.py index 087bc8cfa..19e7203e9 100644 --- a/tests/unit/onnx/quantization/autotune/test_insertion_points.py +++ b/tests/unit/onnx/quantization/autotune/test_insertion_points.py @@ -17,7 +17,7 @@ Comprehensive tests for common data structures in the autotuner. Tests: -1. InsertionPoint classes (NodeInputInsertionPoint, RegionOutputInsertionPoint, ChildRegionInputInsertionPoint) +1. InsertionPoint classes (NodeInputInsertionPoint, ChildRegionOutputInsertionPoint, ChildRegionInputInsertionPoint) 2. InsertionScheme serialization/deserialization 3. InsertionScheme hashing and equality 4. 
InsertionScheme properties and methods @@ -34,10 +34,10 @@ from modelopt.onnx.quantization.autotune.common import ( ChildRegionInputInsertionPoint, + ChildRegionOutputInsertionPoint, InsertionScheme, NodeInputInsertionPoint, Region, - RegionOutputInsertionPoint, RegionType, ) from modelopt.onnx.quantization.autotune.insertion_points import ( @@ -106,83 +106,83 @@ def test_string_representation(self): assert "1" in s -class TestRegionOutputInsertionPoint(unittest.TestCase): - """Test RegionOutputInsertionPoint functionality.""" +class TestChildRegionOutputInsertionPoint(unittest.TestCase): + """Test ChildRegionOutputInsertionPoint functionality.""" def test_creation_with_region_index(self): """Test creating with region_index (child region output).""" - point = RegionOutputInsertionPoint(region_index=2, node_index=None, output_index=1) + point = ChildRegionOutputInsertionPoint(region_index=2, node_index=None, output_index=1) assert point.region_index == 2 assert point.node_index is None assert point.output_index == 1 def test_creation_with_node_index(self): """Test creating with node_index (node output).""" - point = RegionOutputInsertionPoint(region_index=None, node_index=5, output_index=0) + point = ChildRegionOutputInsertionPoint(region_index=None, node_index=5, output_index=0) assert point.region_index is None assert point.node_index == 5 assert point.output_index == 0 def test_immutability(self): - """Test that RegionOutputInsertionPoint is immutable (frozen).""" - point = RegionOutputInsertionPoint(region_index=1, node_index=None, output_index=0) + """Test that ChildRegionOutputInsertionPoint is immutable (frozen).""" + point = ChildRegionOutputInsertionPoint(region_index=1, node_index=None, output_index=0) passed = False try: point.region_index = 2 except AttributeError: passed = True - assert passed, "RegionOutputInsertionPoint should be immutable" + assert passed, "ChildRegionOutputInsertionPoint should be immutable" def test_equality(self): """Test equality 
comparison.""" - point1 = RegionOutputInsertionPoint(region_index=1, node_index=None, output_index=0) - point2 = RegionOutputInsertionPoint(region_index=1, node_index=None, output_index=0) - point3 = RegionOutputInsertionPoint(region_index=None, node_index=1, output_index=0) + point1 = ChildRegionOutputInsertionPoint(region_index=1, node_index=None, output_index=0) + point2 = ChildRegionOutputInsertionPoint(region_index=1, node_index=None, output_index=0) + point3 = ChildRegionOutputInsertionPoint(region_index=None, node_index=1, output_index=0) assert point1 == point2 assert point1 != point3 def test_hashable(self): """Test that points can be used in sets and dicts.""" - point1 = RegionOutputInsertionPoint(region_index=1, node_index=None, output_index=0) - point2 = RegionOutputInsertionPoint(region_index=1, node_index=None, output_index=0) - point3 = RegionOutputInsertionPoint(region_index=None, node_index=1, output_index=0) + point1 = ChildRegionOutputInsertionPoint(region_index=1, node_index=None, output_index=0) + point2 = ChildRegionOutputInsertionPoint(region_index=1, node_index=None, output_index=0) + point3 = ChildRegionOutputInsertionPoint(region_index=None, node_index=1, output_index=0) point_set = {point1, point2, point3} assert len(point_set) == 2 # point1 and point2 are the same def test_serialization_region_index(self): """Test serialization with region_index.""" - point = RegionOutputInsertionPoint(region_index=3, node_index=None, output_index=2) + point = ChildRegionOutputInsertionPoint(region_index=3, node_index=None, output_index=2) data = point.to_dict() assert data["region_index"] == 3 assert data["node_index"] is None assert data["output_index"] == 2 - restored = RegionOutputInsertionPoint.from_dict(data) + restored = ChildRegionOutputInsertionPoint.from_dict(data) assert point == restored def test_serialization_node_index(self): """Test serialization with node_index.""" - point = RegionOutputInsertionPoint(region_index=None, node_index=7, 
output_index=1) + point = ChildRegionOutputInsertionPoint(region_index=None, node_index=7, output_index=1) data = point.to_dict() assert data["region_index"] is None assert data["node_index"] == 7 assert data["output_index"] == 1 - restored = RegionOutputInsertionPoint.from_dict(data) + restored = ChildRegionOutputInsertionPoint.from_dict(data) assert point == restored def test_string_representation(self): """Test __str__ method.""" - point1 = RegionOutputInsertionPoint(region_index=2, node_index=None, output_index=1) + point1 = ChildRegionOutputInsertionPoint(region_index=2, node_index=None, output_index=1) s1 = str(point1) assert "region" in s1.lower() assert "2" in s1 - point2 = RegionOutputInsertionPoint(region_index=None, node_index=5, output_index=0) + point2 = ChildRegionOutputInsertionPoint(region_index=None, node_index=5, output_index=0) s2 = str(point2) assert "node" in s2.lower() assert "5" in s2 @@ -269,8 +269,8 @@ def test_scheme_with_region_outputs(self): """Test scheme with region output insertion points.""" scheme = InsertionScheme() scheme.region_outputs = [ - RegionOutputInsertionPoint(None, 0, 0), - RegionOutputInsertionPoint(1, None, 0), + ChildRegionOutputInsertionPoint(None, 0, 0), + ChildRegionOutputInsertionPoint(1, None, 0), ] assert not scheme.is_empty @@ -341,7 +341,7 @@ def test_serialization_full(self): scheme = InsertionScheme() scheme.node_inputs = [NodeInputInsertionPoint(0, 0)] scheme.child_region_inputs = [ChildRegionInputInsertionPoint(0, 0)] - scheme.region_outputs = [RegionOutputInsertionPoint(None, 0, 0)] + scheme.region_outputs = [ChildRegionOutputInsertionPoint(None, 0, 0)] scheme.latency_ms = 12.5 scheme.error = False @@ -367,11 +367,6 @@ def test_serialization_with_error(self): assert restored.latency_ms == float("inf") -# ============================================================================= -# Helper functions for creating mock graphs -# ============================================================================= 
- - def _create_mock_tensor(name: str, dtype=np.float32, shape=None): """Create a mock tensor with the specified properties.""" tensor = MagicMock() @@ -542,11 +537,6 @@ def _create_residual_graph(): return graph, tensors -# ============================================================================= -# Utility Function Tests -# ============================================================================= - - class TestSkipInvalidInsertionPoints(unittest.TestCase): """Test skip_invalid_insertion_points function.""" @@ -639,14 +629,14 @@ def test_with_region(self): # Create a region containing Conv and BatchNorm nodes region = Region(region_id=1, level=0, region_type=RegionType.LEAF) - region.add_node(0) # Conv node - region.add_node(1) # BatchNorm node + region.nodes.add(0) # Conv node + region.nodes.add(1) # BatchNorm node # Create a shape operation node and add to graph shape_tensor = _create_mock_tensor("shape_input", np.float32) shape_node = _create_mock_node("Shape", [shape_tensor], []) graph.nodes.append(shape_node) - region.add_node(4) # Add the shape node to region + region.nodes.add(4) # Add the shape node to region result = skip_invalid_insertion_points(graph, "shape_input", region) assert result is True @@ -682,7 +672,7 @@ def test_leaf_with_conv(self): graph, _ = _create_simple_graph() region = Region(region_id=1, level=0, region_type=RegionType.LEAF) - region.add_node(0) # Conv node + region.nodes.add(0) # Conv node result = has_quantizable_operations(region, graph) assert result is True @@ -692,7 +682,7 @@ def test_leaf_with_maxpool(self): graph, _ = _create_simple_graph() region = Region(region_id=1, level=0, region_type=RegionType.LEAF) - region.add_node(3) # MaxPool node + region.nodes.add(3) # MaxPool node result = has_quantizable_operations(region, graph) assert result is True @@ -702,7 +692,7 @@ def test_leaf_with_relu_only(self): graph, _ = _create_simple_graph() region = Region(region_id=1, level=0, region_type=RegionType.LEAF) - 
region.add_node(2) # Relu node only (index 2 in new graph) + region.nodes.add(2) # Relu node only (index 2 in new graph) result = has_quantizable_operations(region, graph) assert result is True # Relu is in MAJOR_QUANTIZABLE_OPERATIONS @@ -712,9 +702,9 @@ def test_leaf_with_conv_bn_relu(self): graph, _ = _create_simple_graph() region = Region(region_id=1, level=0, region_type=RegionType.LEAF) - region.add_node(0) # Conv - region.add_node(1) # BatchNorm - region.add_node(2) # Relu + region.nodes.add(0) # Conv + region.nodes.add(1) # BatchNorm + region.nodes.add(2) # Relu result = has_quantizable_operations(region, graph) assert result is True @@ -731,8 +721,8 @@ def test_leaf_without_quantizable_ops(self): graph.nodes = [shape_node, transpose_node] region = Region(region_id=1, level=0, region_type=RegionType.LEAF) - region.add_node(0) - region.add_node(1) + region.nodes.add(0) + region.nodes.add(1) result = has_quantizable_operations(region, graph) assert result is False @@ -752,7 +742,7 @@ def test_residual_block_has_quantizable_ops(self): graph, _ = _create_residual_graph() region = Region(region_id=1, level=0, region_type=RegionType.LEAF) - region.add_node(3) # Add node + region.nodes.add(3) # Add node result = has_quantizable_operations(region, graph) assert result is True # Add is in MAJOR_QUANTIZABLE_OPERATIONS @@ -769,7 +759,7 @@ def test_resolve_with_region(self): graph.tensor_users_map = get_tensor_consumer_node_indices(graph) region = Region(region_id=1, level=0, region_type=RegionType.LEAF) - region.add_node(2) # Relu node + region.nodes.add(2) # Relu node result = resolve_region_io_insertion_points(region, graph, "relu_out") @@ -823,7 +813,7 @@ def test_resolve_with_multiple_consumers(self): graph.tensor_users_map = {"relu1_out": [2]} region = Region(region_id=1, level=0, region_type=RegionType.LEAF) - region.add_node(2) # Conv2 + region.nodes.add(2) # Conv2 result = resolve_region_io_insertion_points(region, graph, "relu1_out") @@ -948,11 +938,6 @@ def 
test_no_merge_residual_partial(self): assert ip.node_index == 0 # Still node-specific -# ============================================================================= -# Resolve Method Tests -# ============================================================================= - - class TestNodeInputInsertionPointResolve(unittest.TestCase): """Test NodeInputInsertionPoint.resolve() method.""" @@ -961,10 +946,10 @@ def test_resolve_simple(self): graph, tensors = _create_simple_graph() region = Region(region_id=1, level=0, region_type=RegionType.LEAF) - region.add_node(0) # Conv node - region.add_node(1) # BatchNorm node - region.add_node(2) # Relu node - region.add_node(3) # MaxPool node + region.nodes.add(0) # Conv node + region.nodes.add(1) # BatchNorm node + region.nodes.add(2) # Relu node + region.nodes.add(3) # MaxPool node # Create insertion point for first input of first node (Conv) ip = NodeInputInsertionPoint(node_index=0, input_index=0) @@ -979,7 +964,7 @@ def test_resolve_conv_includes_weight(self): graph, tensors = _create_simple_graph() region = Region(region_id=1, level=0, region_type=RegionType.LEAF) - region.add_node(0) # Conv node + region.nodes.add(0) # Conv node # Create insertion point for first input of Conv (should also add weight) ip = NodeInputInsertionPoint(node_index=0, input_index=0) @@ -997,9 +982,9 @@ def test_resolve_relu_input(self): graph, tensors = _create_simple_graph() region = Region(region_id=1, level=0, region_type=RegionType.LEAF) - region.add_node(0) # Conv - region.add_node(1) # BatchNorm - region.add_node(2) # Relu + region.nodes.add(0) # Conv + region.nodes.add(1) # BatchNorm + region.nodes.add(2) # Relu # Relu is at local index 2, input 0 is bn_out ip = NodeInputInsertionPoint(node_index=2, input_index=0) @@ -1015,9 +1000,9 @@ def test_resolve_residual_conv_input(self): graph, tensors = _create_residual_graph() region = Region(region_id=1, level=0, region_type=RegionType.LEAF) - region.add_node(0) # Conv1 - region.add_node(1) # 
Relu1 - region.add_node(2) # Conv2 + region.nodes.add(0) # Conv1 + region.nodes.add(1) # Relu1 + region.nodes.add(2) # Conv2 # Conv2 is at local index 2, input 0 is relu1_out ip = NodeInputInsertionPoint(node_index=2, input_index=0) @@ -1043,9 +1028,9 @@ def test_resolve_composite_region(self): parent = Region(region_id=1, level=1, region_type=RegionType.COMPOSITE) child = Region(region_id=2, level=0, region_type=RegionType.LEAF) child.inputs = ["input"] - child.add_node(0) # Conv - child.add_node(1) # BatchNorm - child.add_node(2) # Relu + child.nodes.add(0) # Conv + child.nodes.add(1) # BatchNorm + child.nodes.add(2) # Relu parent.add_child(child) ip = ChildRegionInputInsertionPoint(region_index=0, input_index=0) @@ -1060,7 +1045,7 @@ def test_resolve_leaf_returns_empty(self): graph, _ = _create_simple_graph() leaf = Region(region_id=1, level=0, region_type=RegionType.LEAF) - leaf.add_node(0) + leaf.nodes.add(0) ip = ChildRegionInputInsertionPoint(region_index=0, input_index=0) @@ -1080,12 +1065,12 @@ def test_resolve_multiple_children(self): # First child: Conv1 (consumes "input") child1 = Region(region_id=2, level=0, region_type=RegionType.LEAF) child1.inputs = ["input"] - child1.add_node(0) # Conv1 + child1.nodes.add(0) # Conv1 # Second child: Relu1 (consumes "relu1_out") child2 = Region(region_id=3, level=0, region_type=RegionType.LEAF) child2.inputs = ["relu1_out"] - child2.add_node(2) # Relu1 + child2.nodes.add(2) # Relu1 parent.add_child(child1) parent.add_child(child2) @@ -1105,8 +1090,8 @@ def test_resolve_multiple_children(self): assert any(rip.tensor_name == "relu1_out" for rip in result2) -class TestRegionOutputInsertionPointResolve(unittest.TestCase): - """Test RegionOutputInsertionPoint.resolve() method.""" +class TestChildRegionOutputInsertionPointResolve(unittest.TestCase): + """Test ChildRegionOutputInsertionPoint.resolve() method.""" def test_resolve_node_output(self): """Test resolving a node output.""" @@ -1114,14 +1099,14 @@ def 
test_resolve_node_output(self): graph.tensor_users_map = get_tensor_consumer_node_indices(graph) region = Region(region_id=1, level=0, region_type=RegionType.LEAF) - region.add_node(0) # Conv - region.add_node(1) # BatchNorm - region.add_node(2) # Relu - region.add_node(3) # MaxPool + region.nodes.add(0) # Conv + region.nodes.add(1) # BatchNorm + region.nodes.add(2) # Relu + region.nodes.add(3) # MaxPool region.outputs = ["pool_out"] # Output of last node (MaxPool) - ip = RegionOutputInsertionPoint(region_index=None, node_index=2, output_index=0) + ip = ChildRegionOutputInsertionPoint(region_index=None, node_index=2, output_index=0) result = ip.resolve(region, graph) @@ -1136,12 +1121,12 @@ def test_resolve_child_region_output(self): parent = Region(region_id=1, level=1, region_type=RegionType.COMPOSITE) child = Region(region_id=2, level=0, region_type=RegionType.LEAF) child.outputs = ["relu_out"] - child.add_node(0) # Conv - child.add_node(1) # BatchNorm - child.add_node(2) # Relu + child.nodes.add(0) # Conv + child.nodes.add(1) # BatchNorm + child.nodes.add(2) # Relu parent.add_child(child) - ip = RegionOutputInsertionPoint(region_index=0, node_index=None, output_index=0) + ip = ChildRegionOutputInsertionPoint(region_index=0, node_index=None, output_index=0) result = ip.resolve(parent, graph) @@ -1154,15 +1139,15 @@ def test_resolve_residual_add_output(self): graph.tensor_users_map = {"add_out": [4]} region = Region(region_id=1, level=0, region_type=RegionType.LEAF) - region.add_node(0) # Conv1 - region.add_node(1) # Relu1 - region.add_node(2) # Conv2 - region.add_node(3) # Add - region.add_node(4) # Relu2 + region.nodes.add(0) # Conv1 + region.nodes.add(1) # Relu1 + region.nodes.add(2) # Conv2 + region.nodes.add(3) # Add + region.nodes.add(4) # Relu2 region.outputs = ["add_out"] # Add is at local index 3, output 0 - ip = RegionOutputInsertionPoint(region_index=None, node_index=3, output_index=0) + ip = ChildRegionOutputInsertionPoint(region_index=None, 
node_index=3, output_index=0) result = ip.resolve(region, graph) @@ -1170,11 +1155,6 @@ def test_resolve_residual_add_output(self): assert any(rip.tensor_name == "add_out" for rip in result) -# ============================================================================= -# Collect From Region Tests -# ============================================================================= - - class TestNodeInputInsertionPointCollectFrom(unittest.TestCase): """Test NodeInputInsertionPoint.collect_from_region() method.""" @@ -1183,10 +1163,10 @@ def test_collect_valid_inputs(self): graph, tensors = _create_simple_graph() region = Region(region_id=1, level=0, region_type=RegionType.LEAF) - region.add_node(0) # Conv - region.add_node(1) # BatchNorm - region.add_node(2) # Relu - region.add_node(3) # MaxPool + region.nodes.add(0) # Conv + region.nodes.add(1) # BatchNorm + region.nodes.add(2) # Relu + region.nodes.add(3) # MaxPool result = NodeInputInsertionPoint.collect_from_region(region, graph) @@ -1200,11 +1180,11 @@ def test_collect_from_residual_block(self): graph, tensors = _create_residual_graph() region = Region(region_id=1, level=0, region_type=RegionType.LEAF) - region.add_node(0) # Conv1 - region.add_node(1) # Relu1 - region.add_node(2) # Conv2 - region.add_node(3) # Add - region.add_node(4) # Relu2 + region.nodes.add(0) # Conv1 + region.nodes.add(1) # Relu1 + region.nodes.add(2) # Conv2 + region.nodes.add(3) # Add + region.nodes.add(4) # Relu2 result = NodeInputInsertionPoint.collect_from_region(region, graph) @@ -1227,9 +1207,9 @@ def test_collect_from_composite(self): parent = Region(region_id=1, level=1, region_type=RegionType.COMPOSITE) child = Region(region_id=2, level=0, region_type=RegionType.LEAF) child.inputs = ["input"] - child.add_node(0) # Conv - child.add_node(1) # BatchNorm - child.add_node(2) # Relu + child.nodes.add(0) # Conv + child.nodes.add(1) # BatchNorm + child.nodes.add(2) # Relu parent.add_child(child) result = 
ChildRegionInputInsertionPoint.collect_from_region(parent, graph) @@ -1243,7 +1223,7 @@ def test_collect_from_leaf_returns_empty(self): graph, _ = _create_simple_graph() leaf = Region(region_id=1, level=0, region_type=RegionType.LEAF) - leaf.add_node(0) + leaf.nodes.add(0) result = ChildRegionInputInsertionPoint.collect_from_region(leaf, graph) @@ -1257,13 +1237,13 @@ def test_collect_from_composite_with_multiple_children(self): child1 = Region(region_id=2, level=0, region_type=RegionType.LEAF) child1.inputs = ["input"] - child1.add_node(0) # Conv1 - child1.add_node(1) # Relu1 + child1.nodes.add(0) # Conv1 + child1.nodes.add(1) # Relu1 child2 = Region(region_id=3, level=0, region_type=RegionType.LEAF) child2.inputs = ["relu1_out", "input"] # Two inputs including skip connection - child2.add_node(2) # Conv2 - child2.add_node(3) # Add + child2.nodes.add(2) # Conv2 + child2.nodes.add(3) # Add parent.add_child(child1) parent.add_child(child2) @@ -1274,25 +1254,25 @@ def test_collect_from_composite_with_multiple_children(self): assert all(isinstance(ip, ChildRegionInputInsertionPoint) for ip in result) -class TestRegionOutputInsertionPointCollectFrom(unittest.TestCase): - """Test RegionOutputInsertionPoint.collect_from_region() method.""" +class TestChildRegionOutputInsertionPointCollectFrom(unittest.TestCase): + """Test ChildRegionOutputInsertionPoint.collect_from_region() method.""" def test_collect_node_outputs(self): """Test collecting node output insertion points.""" graph, tensors = _create_simple_graph() region = Region(region_id=1, level=0, region_type=RegionType.LEAF) - region.add_node(0) # Conv - region.add_node(1) # BatchNorm - region.add_node(2) # Relu - region.add_node(3) # MaxPool + region.nodes.add(0) # Conv + region.nodes.add(1) # BatchNorm + region.nodes.add(2) # Relu + region.nodes.add(3) # MaxPool region.outputs = ["pool_out"] # Only pool_out is a region output - result = RegionOutputInsertionPoint.collect_from_region(region, graph) + result = 
ChildRegionOutputInsertionPoint.collect_from_region(region, graph) # Should find the node output that matches region output assert len(result) >= 0 # May be filtered - assert all(isinstance(ip, RegionOutputInsertionPoint) for ip in result) + assert all(isinstance(ip, ChildRegionOutputInsertionPoint) for ip in result) def test_collect_child_region_outputs(self): """Test collecting child region output insertion points.""" @@ -1301,30 +1281,30 @@ def test_collect_child_region_outputs(self): parent = Region(region_id=1, level=1, region_type=RegionType.COMPOSITE) child = Region(region_id=2, level=0, region_type=RegionType.LEAF) child.outputs = ["relu_out"] - child.add_node(0) # Conv - child.add_node(1) # BatchNorm - child.add_node(2) # Relu + child.nodes.add(0) # Conv + child.nodes.add(1) # BatchNorm + child.nodes.add(2) # Relu parent.add_child(child) parent.outputs = ["relu_out"] # Child output is also parent output - result = RegionOutputInsertionPoint.collect_from_region(parent, graph) + result = ChildRegionOutputInsertionPoint.collect_from_region(parent, graph) # Should find the child region output - assert all(isinstance(ip, RegionOutputInsertionPoint) for ip in result) + assert all(isinstance(ip, ChildRegionOutputInsertionPoint) for ip in result) def test_collect_residual_block_outputs(self): """Test collecting outputs from residual block.""" graph, tensors = _create_residual_graph() region = Region(region_id=1, level=0, region_type=RegionType.LEAF) - region.add_node(0) # Conv1 - region.add_node(1) # Relu1 - region.add_node(2) # Conv2 - region.add_node(3) # Add - region.add_node(4) # Relu2 + region.nodes.add(0) # Conv1 + region.nodes.add(1) # Relu1 + region.nodes.add(2) # Conv2 + region.nodes.add(3) # Add + region.nodes.add(4) # Relu2 region.outputs = ["output"] # Final output - result = RegionOutputInsertionPoint.collect_from_region(region, graph) + result = ChildRegionOutputInsertionPoint.collect_from_region(region, graph) # Should find the output - assert 
all(isinstance(ip, RegionOutputInsertionPoint) for ip in result) + assert all(isinstance(ip, ChildRegionOutputInsertionPoint) for ip in result) diff --git a/tests/unit/onnx/quantization/autotune/test_region.py b/tests/unit/onnx/quantization/autotune/test_region.py index df0cfaed3..297d5a2be 100644 --- a/tests/unit/onnx/quantization/autotune/test_region.py +++ b/tests/unit/onnx/quantization/autotune/test_region.py @@ -61,9 +61,9 @@ def test_add_nodes(self): """Test adding nodes to a region.""" region = Region(region_id=1, level=0, region_type=RegionType.LEAF) - region.add_node(0) - region.add_node(1) - region.add_node(2) + region.nodes.add(0) + region.nodes.add(1) + region.nodes.add(2) assert len(region.nodes) == 3 assert 0 in region.get_nodes() @@ -90,18 +90,18 @@ def test_region_size_recursive(self): child2 = Region(region_id=3, level=0, region_type=RegionType.LEAF) # Add nodes to children - child1.add_node(0) - child1.add_node(1) - child2.add_node(2) - child2.add_node(3) - child2.add_node(4) + child1.nodes.add(0) + child1.nodes.add(1) + child2.nodes.add(2) + child2.nodes.add(3) + child2.nodes.add(4) # Add children to parent parent.add_child(child1) parent.add_child(child2) # Parent itself might have direct nodes - parent.add_node(5) + parent.nodes.add(5) # Recursive count should include all nodes assert len(parent.get_region_nodes_and_descendants()) == 6 @@ -143,9 +143,9 @@ def test_hierarchical_structure(self): composite2.add_child(leaf3) # Add some nodes - leaf1.add_node(0) - leaf2.add_node(1) - leaf3.add_node(2) + leaf1.nodes.add(0) + leaf2.nodes.add(1) + leaf3.nodes.add(2) # Verify structure assert len(root.get_children()) == 2 From 0ca17a23802ecd88401f4bf4b4b0e39ccf0d390e Mon Sep 17 00:00:00 2001 From: Will Guo Date: Tue, 27 Jan 2026 23:50:58 +0000 Subject: [PATCH 4/5] fix get_tensor_consumer_node_indices bug Signed-off-by: Will Guo --- modelopt/onnx/quantization/graph_utils.py | 27 +++++++++-------------- 1 file changed, 10 insertions(+), 17 deletions(-) 
diff --git a/modelopt/onnx/quantization/graph_utils.py b/modelopt/onnx/quantization/graph_utils.py index f05a08bfa..efa77dd7b 100755 --- a/modelopt/onnx/quantization/graph_utils.py +++ b/modelopt/onnx/quantization/graph_utils.py @@ -307,27 +307,20 @@ def get_tensor_consumer_node_indices(graph: onnx.GraphProto | gs.Graph) -> dict[ Args: graph: ONNX GraphSurgeon graph to analyze - Returns: Dictionary mapping tensor names to lists of node indices that consume them """ tensor_consumer_map: dict[str, list[int]] = defaultdict(list) - - if isinstance(graph, gs.Graph): - for node_idx, node in enumerate(graph.nodes): - for t in node.inputs: - name = getattr(t, "name", None) - if not name: - continue - tensor_consumer_map[name].append(node_idx) - return tensor_consumer_map - - # onnx.GraphProto case: node.input is repeated string - for node_idx, node in enumerate(graph.node): - for input_name in node.input: - if not input_name: - continue - tensor_consumer_map[input_name].append(node_idx) + nodes = graph.nodes if isinstance(graph, gs.Graph) else graph.node + for node_idx, node in enumerate(nodes): + inputs = node.inputs if isinstance(node, gs.Node) else node.input + for tensor in inputs: + tensor_name = tensor + if isinstance(tensor, str): + tensor_name = tensor + elif hasattr(tensor, "name") and isinstance(tensor.name, str): + tensor_name = tensor.name + tensor_consumer_map[tensor_name].append(node_idx) return tensor_consumer_map From 09a91a8301d8e3cf37510de5e955b8f19c4335b6 Mon Sep 17 00:00:00 2001 From: Will Guo Date: Thu, 29 Jan 2026 09:00:23 +0000 Subject: [PATCH 5/5] parameterize tests Signed-off-by: Will Guo --- .../autotune/test_insertion_points.py | 992 ++++++------------ .../onnx/quantization/autotune/test_region.py | 240 ++--- 2 files changed, 408 insertions(+), 824 deletions(-) diff --git a/tests/unit/onnx/quantization/autotune/test_insertion_points.py b/tests/unit/onnx/quantization/autotune/test_insertion_points.py index 19e7203e9..2818d3172 100644 --- 
a/tests/unit/onnx/quantization/autotune/test_insertion_points.py +++ b/tests/unit/onnx/quantization/autotune/test_insertion_points.py @@ -31,6 +31,7 @@ import numpy as np import onnx_graphsurgeon as gs +import pytest from modelopt.onnx.quantization.autotune.common import ( ChildRegionInputInsertionPoint, @@ -49,322 +50,181 @@ ) from modelopt.onnx.quantization.graph_utils import get_tensor_consumer_node_indices +INSERTION_POINT_CASES = [ + pytest.param( + NodeInputInsertionPoint, + {"node_index": 5, "input_index": 2}, + {"node_index": 5, "input_index": 2}, + {"node_index": 5, "input_index": 3}, + "node_index", + ["5", "2"], + id="NodeInputInsertionPoint", + ), + pytest.param( + ChildRegionOutputInsertionPoint, + {"region_index": 2, "node_index": None, "output_index": 1}, + {"region_index": 2, "node_index": None, "output_index": 1}, + {"region_index": None, "node_index": 2, "output_index": 1}, + "region_index", + ["region", "2"], + id="ChildRegionOutputInsertionPoint-region", + ), + pytest.param( + ChildRegionOutputInsertionPoint, + {"region_index": None, "node_index": 5, "output_index": 0}, + {"region_index": None, "node_index": 5, "output_index": 0}, + {"region_index": None, "node_index": 5, "output_index": 1}, + "node_index", + ["node", "5"], + id="ChildRegionOutputInsertionPoint-node", + ), + pytest.param( + ChildRegionInputInsertionPoint, + {"region_index": 3, "input_index": 1}, + {"region_index": 3, "input_index": 1}, + {"region_index": 3, "input_index": 2}, + "region_index", + ["3", "1"], + id="ChildRegionInputInsertionPoint", + ), +] + + +class TestInsertionPoints: + """Combined tests for all InsertionPoint types.""" + + @pytest.mark.parametrize(("cls", "kwargs", "_", "__", "___", "____"), INSERTION_POINT_CASES) + def test_creation(self, cls, kwargs, _, __, ___, ____): + point = cls(**kwargs) + for key, val in kwargs.items(): + assert getattr(point, key) == val + + @pytest.mark.parametrize( + ("cls", "kwargs", "_", "__", "mutate_attr", "___"), 
INSERTION_POINT_CASES + ) + def test_immutability(self, cls, kwargs, _, __, mutate_attr, ___): + point = cls(**kwargs) + with pytest.raises(AttributeError): + setattr(point, mutate_attr, 999) -class TestNodeInputInsertionPoint(unittest.TestCase): - """Test NodeInputInsertionPoint functionality.""" - - def test_creation(self): - """Test creating NodeInputInsertionPoint.""" - point = NodeInputInsertionPoint(node_index=5, input_index=2) - assert point.node_index == 5 - assert point.input_index == 2 - - def test_immutability(self): - """Test that NodeInputInsertionPoint is immutable (frozen).""" - point = NodeInputInsertionPoint(node_index=1, input_index=0) - passed = False - try: - point.node_index = 2 - except AttributeError: - passed = True - assert passed, "NodeInputInsertionPoint should be immutable" - - def test_equality(self): - """Test equality comparison.""" - point1 = NodeInputInsertionPoint(node_index=3, input_index=1) - point2 = NodeInputInsertionPoint(node_index=3, input_index=1) - point3 = NodeInputInsertionPoint(node_index=3, input_index=2) - - assert point1 == point2 - assert point1 != point3 - - def test_hashable(self): - """Test that points can be used in sets and dicts.""" - point1 = NodeInputInsertionPoint(node_index=1, input_index=0) - point2 = NodeInputInsertionPoint(node_index=1, input_index=0) - point3 = NodeInputInsertionPoint(node_index=2, input_index=0) - - point_set = {point1, point2, point3} - assert len(point_set) == 2 # point1 and point2 are the same - - def test_serialization(self): - """Test to_dict and from_dict.""" - point = NodeInputInsertionPoint(node_index=7, input_index=3) - - data = point.to_dict() - assert data["node_index"] == 7 - assert data["input_index"] == 3 - - restored = NodeInputInsertionPoint.from_dict(data) - assert point == restored - - def test_string_representation(self): - """Test __str__ method.""" - point = NodeInputInsertionPoint(node_index=2, input_index=1) - s = str(point) - assert "2" in s - assert "1" in s - 
- -class TestChildRegionOutputInsertionPoint(unittest.TestCase): - """Test ChildRegionOutputInsertionPoint functionality.""" - - def test_creation_with_region_index(self): - """Test creating with region_index (child region output).""" - point = ChildRegionOutputInsertionPoint(region_index=2, node_index=None, output_index=1) - assert point.region_index == 2 - assert point.node_index is None - assert point.output_index == 1 - - def test_creation_with_node_index(self): - """Test creating with node_index (node output).""" - point = ChildRegionOutputInsertionPoint(region_index=None, node_index=5, output_index=0) - assert point.region_index is None - assert point.node_index == 5 - assert point.output_index == 0 - - def test_immutability(self): - """Test that ChildRegionOutputInsertionPoint is immutable (frozen).""" - point = ChildRegionOutputInsertionPoint(region_index=1, node_index=None, output_index=0) - passed = False - try: - point.region_index = 2 - except AttributeError: - passed = True - assert passed, "ChildRegionOutputInsertionPoint should be immutable" - - def test_equality(self): - """Test equality comparison.""" - point1 = ChildRegionOutputInsertionPoint(region_index=1, node_index=None, output_index=0) - point2 = ChildRegionOutputInsertionPoint(region_index=1, node_index=None, output_index=0) - point3 = ChildRegionOutputInsertionPoint(region_index=None, node_index=1, output_index=0) - - assert point1 == point2 - assert point1 != point3 - - def test_hashable(self): - """Test that points can be used in sets and dicts.""" - point1 = ChildRegionOutputInsertionPoint(region_index=1, node_index=None, output_index=0) - point2 = ChildRegionOutputInsertionPoint(region_index=1, node_index=None, output_index=0) - point3 = ChildRegionOutputInsertionPoint(region_index=None, node_index=1, output_index=0) - - point_set = {point1, point2, point3} - assert len(point_set) == 2 # point1 and point2 are the same - - def test_serialization_region_index(self): - """Test 
serialization with region_index.""" - point = ChildRegionOutputInsertionPoint(region_index=3, node_index=None, output_index=2) - - data = point.to_dict() - assert data["region_index"] == 3 - assert data["node_index"] is None - assert data["output_index"] == 2 - - restored = ChildRegionOutputInsertionPoint.from_dict(data) - assert point == restored - - def test_serialization_node_index(self): - """Test serialization with node_index.""" - point = ChildRegionOutputInsertionPoint(region_index=None, node_index=7, output_index=1) - - data = point.to_dict() - assert data["region_index"] is None - assert data["node_index"] == 7 - assert data["output_index"] == 1 - - restored = ChildRegionOutputInsertionPoint.from_dict(data) - assert point == restored - - def test_string_representation(self): - """Test __str__ method.""" - point1 = ChildRegionOutputInsertionPoint(region_index=2, node_index=None, output_index=1) - s1 = str(point1) - assert "region" in s1.lower() - assert "2" in s1 - - point2 = ChildRegionOutputInsertionPoint(region_index=None, node_index=5, output_index=0) - s2 = str(point2) - assert "node" in s2.lower() - assert "5" in s2 - - -class TestChildRegionInputInsertionPoint(unittest.TestCase): - """Test ChildRegionInputInsertionPoint functionality.""" - - def test_creation(self): - """Test creating ChildRegionInputInsertionPoint.""" - point = ChildRegionInputInsertionPoint(region_index=3, input_index=1) - assert point.region_index == 3 - assert point.input_index == 1 - - def test_immutability(self): - """Test that ChildRegionInputInsertionPoint is immutable (frozen).""" - point = ChildRegionInputInsertionPoint(region_index=1, input_index=0) - passed = False - try: - point.region_index = 2 - except AttributeError: - passed = True - assert passed, "ChildRegionInputInsertionPoint should be immutable" - - def test_equality(self): - """Test equality comparison.""" - point1 = ChildRegionInputInsertionPoint(region_index=2, input_index=0) - point2 = 
ChildRegionInputInsertionPoint(region_index=2, input_index=0) - point3 = ChildRegionInputInsertionPoint(region_index=2, input_index=1) - + @pytest.mark.parametrize( + ("cls", "kwargs", "equal_kwargs", "diff_kwargs", "_", "__"), INSERTION_POINT_CASES + ) + def test_equality(self, cls, kwargs, equal_kwargs, diff_kwargs, _, __): + point1 = cls(**kwargs) + point2 = cls(**equal_kwargs) + point3 = cls(**diff_kwargs) assert point1 == point2 assert point1 != point3 - def test_hashable(self): - """Test that points can be used in sets and dicts.""" - point1 = ChildRegionInputInsertionPoint(region_index=1, input_index=0) - point2 = ChildRegionInputInsertionPoint(region_index=1, input_index=0) - point3 = ChildRegionInputInsertionPoint(region_index=2, input_index=0) - + @pytest.mark.parametrize( + ("cls", "kwargs", "equal_kwargs", "diff_kwargs", "_", "__"), INSERTION_POINT_CASES + ) + def test_hashable(self, cls, kwargs, equal_kwargs, diff_kwargs, _, __): + point1 = cls(**kwargs) + point2 = cls(**equal_kwargs) + point3 = cls(**diff_kwargs) point_set = {point1, point2, point3} - assert len(point_set) == 2 # point1 and point2 are the same - - def test_serialization(self): - """Test to_dict and from_dict.""" - point = ChildRegionInputInsertionPoint(region_index=5, input_index=2) + assert len(point_set) == 2 + @pytest.mark.parametrize(("cls", "kwargs", "_", "__", "___", "____"), INSERTION_POINT_CASES) + def test_serialization(self, cls, kwargs, _, __, ___, ____): + point = cls(**kwargs) data = point.to_dict() - assert data["region_index"] == 5 - assert data["input_index"] == 2 - - restored = ChildRegionInputInsertionPoint.from_dict(data) + for key, val in kwargs.items(): + assert data[key] == val + restored = cls.from_dict(data) assert point == restored - def test_string_representation(self): - """Test __str__ method.""" - point = ChildRegionInputInsertionPoint(region_index=2, input_index=1) - s = str(point) - assert "2" in s - assert "1" in s + @pytest.mark.parametrize( + ("cls", 
"kwargs", "_", "__", "___", "str_checks"), INSERTION_POINT_CASES + ) + def test_string_representation(self, cls, kwargs, _, __, ___, str_checks): + point = cls(**kwargs) + s = str(point).lower() + for check in str_checks: + assert check.lower() in s -class TestInsertionScheme(unittest.TestCase): +class TestInsertionScheme: """Test InsertionScheme functionality.""" def test_empty_scheme(self): """Test empty InsertionScheme.""" scheme = InsertionScheme() - assert scheme.is_empty assert len(scheme.node_inputs) == 0 assert len(scheme.child_region_inputs) == 0 assert len(scheme.region_outputs) == 0 assert not scheme.error - def test_scheme_with_node_inputs(self): - """Test scheme with node input insertion points.""" - scheme = InsertionScheme() - scheme.node_inputs = [NodeInputInsertionPoint(0, 0), NodeInputInsertionPoint(1, 0)] - - assert not scheme.is_empty - assert len(scheme.node_inputs) == 2 - - def test_scheme_with_region_outputs(self): - """Test scheme with region output insertion points.""" - scheme = InsertionScheme() - scheme.region_outputs = [ - ChildRegionOutputInsertionPoint(None, 0, 0), - ChildRegionOutputInsertionPoint(1, None, 0), - ] - - assert not scheme.is_empty - assert len(scheme.region_outputs) == 2 - - def test_scheme_with_composite_regions(self): - """Test scheme with composite region insertion points.""" + @pytest.mark.parametrize( + ("attr", "points"), + [ + ("node_inputs", [NodeInputInsertionPoint(0, 0), NodeInputInsertionPoint(1, 0)]), + ( + "region_outputs", + [ + ChildRegionOutputInsertionPoint(None, 0, 0), + ChildRegionOutputInsertionPoint(1, None, 0), + ], + ), + ( + "child_region_inputs", + [ChildRegionInputInsertionPoint(0, 0), ChildRegionInputInsertionPoint(1, 0)], + ), + ], + ) + def test_scheme_with_points_not_empty(self, attr, points): + """Test scheme with insertion points is not empty.""" scheme = InsertionScheme() - scheme.child_region_inputs = [ - ChildRegionInputInsertionPoint(0, 0), - ChildRegionInputInsertionPoint(1, 0), - ] 
- + setattr(scheme, attr, points) assert not scheme.is_empty - assert len(scheme.child_region_inputs) == 2 + assert len(getattr(scheme, attr)) == 2 def test_scheme_hash_empty(self): - """Test hash of empty scheme.""" - scheme1 = InsertionScheme() - scheme2 = InsertionScheme() - - assert scheme1.hash == scheme2.hash - - def test_scheme_hash_with_points(self): - """Test hash with insertion points.""" - scheme1 = InsertionScheme() - scheme1.node_inputs = [NodeInputInsertionPoint(0, 0), NodeInputInsertionPoint(1, 0)] - - scheme2 = InsertionScheme() - scheme2.node_inputs = [NodeInputInsertionPoint(0, 0), NodeInputInsertionPoint(1, 0)] - - scheme3 = InsertionScheme() - scheme3.node_inputs = [ - NodeInputInsertionPoint(0, 0), - NodeInputInsertionPoint(2, 0), # Different - ] - - assert scheme1.hash == scheme2.hash - assert scheme1.hash != scheme3.hash - - def test_scheme_hash_order_independent(self): - """Test that hash is independent of insertion point order.""" - scheme1 = InsertionScheme() - scheme1.node_inputs = [NodeInputInsertionPoint(0, 0), NodeInputInsertionPoint(1, 0)] - - scheme2 = InsertionScheme() - scheme2.node_inputs = [ - NodeInputInsertionPoint(1, 0), - NodeInputInsertionPoint(0, 0), # Reversed order - ] - - # Hash should be the same regardless of order - assert scheme1.hash == scheme2.hash - - def test_serialization_empty(self): - """Test serialization of empty scheme.""" + """Test hash of empty schemes are equal.""" + assert InsertionScheme().hash == InsertionScheme().hash + + def test_scheme_hash_equality(self): + """Test hash with same/different insertion points.""" + + def make_scheme(*node_indices): + s = InsertionScheme() + s.node_inputs = [NodeInputInsertionPoint(i, 0) for i in node_indices] + return s + + assert make_scheme(0, 1).hash == make_scheme(0, 1).hash + assert make_scheme(0, 1).hash == make_scheme(1, 0).hash # order independent + assert make_scheme(0, 1).hash != make_scheme(0, 2).hash + + @pytest.mark.parametrize( + ("error", "latency"), + 
[ + (False, float("inf")), # empty + (False, 12.5), # full + (True, float("inf")), # with error + ], + ) + def test_serialization_roundtrip(self, error, latency): + """Test serialization roundtrip.""" scheme = InsertionScheme() + scheme.error = error + scheme.latency_ms = latency - data = scheme.to_dict() - restored = InsertionScheme.from_dict(data) - - assert restored.is_empty - assert restored.latency_ms == float("inf") - assert not restored.error + if latency != float("inf") or error: # add points for non-empty cases + scheme.node_inputs = [NodeInputInsertionPoint(0, 0)] + scheme.child_region_inputs = [ChildRegionInputInsertionPoint(0, 0)] + scheme.region_outputs = [ChildRegionOutputInsertionPoint(None, 0, 0)] - def test_serialization_full(self): - """Test serialization with all types of insertion points.""" - scheme = InsertionScheme() - scheme.node_inputs = [NodeInputInsertionPoint(0, 0)] - scheme.child_region_inputs = [ChildRegionInputInsertionPoint(0, 0)] - scheme.region_outputs = [ChildRegionOutputInsertionPoint(None, 0, 0)] - scheme.latency_ms = 12.5 - scheme.error = False - - data = scheme.to_dict() - restored = InsertionScheme.from_dict(data) - - assert len(restored.node_inputs) == 1 - assert len(restored.child_region_inputs) == 1 - assert len(restored.region_outputs) == 1 - assert restored.latency_ms == 12.5 - assert not restored.error - - def test_serialization_with_error(self): - """Test serialization with error flag.""" - scheme = InsertionScheme() - scheme.error = True - scheme.latency_ms = float("inf") - - data = scheme.to_dict() - restored = InsertionScheme.from_dict(data) + restored = InsertionScheme.from_dict(scheme.to_dict()) - assert restored.error - assert restored.latency_ms == float("inf") + assert restored.error == error + assert restored.latency_ms == latency + if not scheme.is_empty: + assert len(restored.node_inputs) == len(scheme.node_inputs) + assert len(restored.child_region_inputs) == len(scheme.child_region_inputs) + assert 
len(restored.region_outputs) == len(scheme.region_outputs) def _create_mock_tensor(name: str, dtype=np.float32, shape=None): @@ -387,17 +247,29 @@ def _create_mock_node(op: str, inputs: list, outputs: list, name: str = ""): return node +def _create_region(region_id=1, level=0, region_type=RegionType.LEAF, nodes=None): + """Create a region with the specified properties. + + Args: + region_id: ID for the region + level: Hierarchy level (0 for LEAF, 1+ for COMPOSITE/ROOT) + region_type: Type of region (LEAF, COMPOSITE, or ROOT) + nodes: Optional list/set of node indices to add to the region + + Returns: + Region with specified properties and nodes + """ + region = Region(region_id=region_id, level=level, region_type=region_type) + if nodes: + region.nodes.update(nodes) + return region + + def _create_simple_graph(): """Create a mock graph with Conv -> BatchNorm -> Relu -> MaxPool pattern. Graph structure: input -> Conv -> conv_out -> BatchNorm -> bn_out -> Relu -> relu_out -> MaxPool -> pool_out - - Node indices: - 0: Conv - 1: BatchNormalization - 2: Relu - 3: MaxPool """ # Create tensors with realistic shapes input_tensor = _create_mock_tensor("input", np.float32, [1, 3, 224, 224]) @@ -476,13 +348,6 @@ def _create_residual_graph(): │ ▼ Relu2 -> output - - Node indices: - 0: Conv1 - 1: Relu1 - 2: Conv2 - 3: Add - 4: Relu2 """ # Create tensors input_tensor = _create_mock_tensor("input", np.float32, [1, 64, 56, 56]) @@ -537,215 +402,113 @@ def _create_residual_graph(): return graph, tensors -class TestSkipInvalidInsertionPoints(unittest.TestCase): +class TestSkipInvalidInsertionPoints: """Test skip_invalid_insertion_points function.""" - def test_skip_bool_operations(self): - """Test that boolean operations are skipped.""" + @pytest.mark.parametrize( + ("op", "should_skip"), + [ + ("Equal", True), # bool op + ("Shape", True), # shape op + ("MatMul", False), # normal op + ("Add", False), # normal op + ], + ) + def test_skip_by_op_type(self, op, should_skip): graph, _ = 
_create_simple_graph() - - # Create a node with boolean operation - bool_tensor = _create_mock_tensor("bool_input", np.float32) - bool_node = _create_mock_node("Equal", [bool_tensor], []) - - result = skip_invalid_insertion_points(graph, "bool_input", bool_node) - assert result is True - - def test_skip_shape_operations(self): - """Test that shape operations are skipped.""" + tensor = _create_mock_tensor("test_input", np.float32, [1, 64, 32, 32]) + node = _create_mock_node(op, [tensor], []) + assert skip_invalid_insertion_points(graph, "test_input", node) is should_skip + + @pytest.mark.parametrize( + ("dtype", "shape", "should_skip"), + [ + (np.int32, [1, 64, 32, 32], True), # non-float + (np.float32, [1], True), # small tensor + (np.float32, [1, 64, 32, 32], False), # large float - OK + ], + ) + def test_skip_by_tensor_properties(self, dtype, shape, should_skip): graph, _ = _create_simple_graph() - - shape_tensor = _create_mock_tensor("shape_input", np.float32) - shape_node = _create_mock_node("Shape", [shape_tensor], []) - - result = skip_invalid_insertion_points(graph, "shape_input", shape_node) - assert result is True + tensor = _create_mock_tensor("test", dtype, shape) + node = _create_mock_node("Add", [tensor], []) + assert skip_invalid_insertion_points(graph, "test", node) is should_skip def test_skip_conv_weight_input(self): - """Test that Conv weight inputs (index >= 1) are skipped.""" - graph, tensors = _create_simple_graph() - conv_node = graph.nodes[0] - - # Weight is at input index 1 - result = skip_invalid_insertion_points(graph, "conv_weight", conv_node) - assert result is True - - def test_allow_conv_data_input(self): - """Test that Conv data input (index 0) is allowed.""" - graph, tensors = _create_simple_graph() - - # Create a MatMul node that consumes the input tensor (not Conv-related skip) - input_tensor = _create_mock_tensor("matmul_input", np.float32, [1, 3, 224, 224]) - matmul_node = _create_mock_node("MatMul", [input_tensor], []) - - 
result = skip_invalid_insertion_points(graph, "matmul_input", matmul_node) - assert result is False - - def test_skip_non_float_tensors(self): - """Test that non-floating-point tensors are skipped.""" + """Conv weight inputs (index >= 1) are skipped.""" graph, _ = _create_simple_graph() - - # Create int tensor - int_tensor = _create_mock_tensor("int_input", np.int32) - node = _create_mock_node("Add", [int_tensor], []) - - result = skip_invalid_insertion_points(graph, "int_input", node) + result = skip_invalid_insertion_points(graph, "conv_weight", graph.nodes[0]) assert result is True - def test_skip_small_tensors(self): - """Test that small tensors (< 8 elements) are skipped.""" + def test_skip_bn_non_data_inputs(self): + """BatchNormalization non-data inputs are skipped.""" graph, _ = _create_simple_graph() - - # Create small tensor (scalar) - small_tensor = _create_mock_tensor("small", np.float32, [1]) - node = _create_mock_node("Add", [small_tensor], []) - - result = skip_invalid_insertion_points(graph, "small", node) + result = skip_invalid_insertion_points(graph, "bn_scale", graph.nodes[1]) assert result is True - def test_allow_large_float_tensors(self): - """Test that large floating-point tensors are allowed.""" + def test_skip_conv_bn_relu_fusion(self): + """Conv->BN->Relu fusion patterns are skipped at intermediate points.""" graph, _ = _create_simple_graph() - - # Create large float tensor - large_tensor = _create_mock_tensor("large", np.float32, [1, 64, 32, 32]) - node = _create_mock_node("Add", [large_tensor], []) - - result = skip_invalid_insertion_points(graph, "large", node) - assert result is False - - def test_skip_bn_non_data_inputs(self): - """Test that BatchNormalization non-data inputs are skipped.""" - graph, tensors = _create_simple_graph() - bn_node = graph.nodes[1] # BatchNormalization node - - # Scale is at input index 1, should be skipped - result = skip_invalid_insertion_points(graph, "bn_scale", bn_node) + result = 
skip_invalid_insertion_points(graph, "bn_out", graph.nodes[2]) assert result is True def test_with_region(self): - """Test skip_invalid_insertion_points with a Region containing multiple nodes.""" - graph, tensors = _create_simple_graph() - - # Create a region containing Conv and BatchNorm nodes - region = Region(region_id=1, level=0, region_type=RegionType.LEAF) - region.nodes.add(0) # Conv node - region.nodes.add(1) # BatchNorm node + """Test with a Region containing multiple nodes.""" + graph, _ = _create_simple_graph() + region = _create_region(nodes=[0, 1]) - # Create a shape operation node and add to graph shape_tensor = _create_mock_tensor("shape_input", np.float32) shape_node = _create_mock_node("Shape", [shape_tensor], []) graph.nodes.append(shape_node) - region.nodes.add(4) # Add the shape node to region - - result = skip_invalid_insertion_points(graph, "shape_input", region) - assert result is True + region.nodes.add(4) - def test_skip_conv_bn_relu_fusion(self): - """Test that Conv->BN->Relu fusion patterns are skipped at intermediate points.""" - graph, tensors = _create_simple_graph() - relu_node = graph.nodes[2] # Relu node - - # Relu input (bn_out) should be skipped when preceded by Conv->BN - result = skip_invalid_insertion_points(graph, "bn_out", relu_node) - assert result is True - - def test_residual_block_add_inputs(self): - """Test insertion points in a residual block with skip connection.""" - graph, tensors = _create_residual_graph() - add_node = graph.nodes[3] # Add node + assert skip_invalid_insertion_points(graph, "shape_input", region) is True - # Add's first input (conv2_out) should be allowed - result = skip_invalid_insertion_points(graph, "conv2_out", add_node) - assert result is False + def test_residual_block_add_inputs_allowed(self): + """Add node inputs in residual blocks should be allowed.""" + graph, _ = _create_residual_graph() + add_node = graph.nodes[3] - # Add's second input (skip connection input) should also be allowed - 
result = skip_invalid_insertion_points(graph, "input", add_node) - assert result is False + assert skip_invalid_insertion_points(graph, "conv2_out", add_node) is False + assert skip_invalid_insertion_points(graph, "input", add_node) is False -class TestHasQuantizableOperations(unittest.TestCase): +class TestHasQuantizableOperations: """Test has_quantizable_operations function.""" - def test_leaf_with_conv(self): - """Test LEAF region with Conv operation.""" - graph, _ = _create_simple_graph() - - region = Region(region_id=1, level=0, region_type=RegionType.LEAF) - region.nodes.add(0) # Conv node - - result = has_quantizable_operations(region, graph) - assert result is True - - def test_leaf_with_maxpool(self): - """Test LEAF region with MaxPool (a major quantizable op).""" - graph, _ = _create_simple_graph() - - region = Region(region_id=1, level=0, region_type=RegionType.LEAF) - region.nodes.add(3) # MaxPool node - - result = has_quantizable_operations(region, graph) - assert result is True - - def test_leaf_with_relu_only(self): - """Test LEAF region with only Relu.""" - graph, _ = _create_simple_graph() - - region = Region(region_id=1, level=0, region_type=RegionType.LEAF) - region.nodes.add(2) # Relu node only (index 2 in new graph) - - result = has_quantizable_operations(region, graph) - assert result is True # Relu is in MAJOR_QUANTIZABLE_OPERATIONS - - def test_leaf_with_conv_bn_relu(self): - """Test LEAF region with Conv->BN->Relu pattern.""" - graph, _ = _create_simple_graph() - - region = Region(region_id=1, level=0, region_type=RegionType.LEAF) - region.nodes.add(0) # Conv - region.nodes.add(1) # BatchNorm - region.nodes.add(2) # Relu - - result = has_quantizable_operations(region, graph) - assert result is True + @pytest.mark.parametrize( + ("nodes", "graph_fn", "expected"), + [ + ({0}, _create_simple_graph, True), # Conv + ({3}, _create_simple_graph, True), # MaxPool + ({2}, _create_simple_graph, True), # Relu + ({0, 1, 2}, _create_simple_graph, True), 
# Conv->BN->Relu + ({3}, _create_residual_graph, True), # Add in residual + ], + ) + def test_leaf_with_quantizable_ops(self, nodes, graph_fn, expected): + """Test LEAF region with various quantizable operations.""" + graph, _ = graph_fn() + region = _create_region(nodes=nodes) + assert has_quantizable_operations(region, graph) is expected def test_leaf_without_quantizable_ops(self): """Test LEAF region without major quantizable operations.""" - # Create graph with only shape operations shape_tensor = _create_mock_tensor("input", np.float32) output_tensor = _create_mock_tensor("output", np.float32) shape_node = _create_mock_node("Shape", [shape_tensor], [output_tensor]) transpose_node = _create_mock_node("Transpose", [output_tensor], []) - graph = MagicMock(spec=gs.Graph) graph.nodes = [shape_node, transpose_node] + region = _create_region(nodes={0, 1}) - region = Region(region_id=1, level=0, region_type=RegionType.LEAF) - region.nodes.add(0) - region.nodes.add(1) - - result = has_quantizable_operations(region, graph) - assert result is False + assert has_quantizable_operations(region, graph) is False def test_composite_region_always_true(self): """Test that COMPOSITE regions always return True.""" graph, _ = _create_simple_graph() - - region = Region(region_id=1, level=1, region_type=RegionType.COMPOSITE) - # Don't add any nodes - COMPOSITE regions assume children have quantizable ops - - result = has_quantizable_operations(region, graph) - assert result is True - - def test_residual_block_has_quantizable_ops(self): - """Test residual block with Add operation.""" - graph, _ = _create_residual_graph() - - region = Region(region_id=1, level=0, region_type=RegionType.LEAF) - region.nodes.add(3) # Add node - - result = has_quantizable_operations(region, graph) - assert result is True # Add is in MAJOR_QUANTIZABLE_OPERATIONS + region = _create_region(level=1, region_type=RegionType.COMPOSITE) + assert has_quantizable_operations(region, graph) is True class 
TestResolveRegionIOInsertionPoints(unittest.TestCase): @@ -757,10 +520,7 @@ def test_resolve_with_region(self): # Set up tensor_users_map: conv_out is consumed by BatchNorm (node 1) graph.tensor_users_map = get_tensor_consumer_node_indices(graph) - - region = Region(region_id=1, level=0, region_type=RegionType.LEAF) - region.nodes.add(2) # Relu node - + region = _create_region(nodes=[2]) # Relu node result = resolve_region_io_insertion_points(region, graph, "relu_out") assert len(result) >= 1 @@ -772,7 +532,6 @@ def test_resolve_without_region(self): # Set up tensor_users_map: bn_out is consumed by Relu (node 2) graph.tensor_users_map = get_tensor_consumer_node_indices(graph) - result = resolve_region_io_insertion_points(None, graph, "relu_out") assert len(result) == 1 @@ -785,7 +544,6 @@ def test_resolve_tensor_not_found(self): """Test resolving a tensor that has no users.""" graph, _ = _create_simple_graph() graph.tensor_users_map = {} - result = resolve_region_io_insertion_points(None, graph, "nonexistent") assert len(result) == 0 @@ -796,7 +554,6 @@ def test_resolve_residual_skip_connection(self): # Input tensor is used by Conv1 (node 0) and Add (node 3) graph.tensor_users_map = {"input": [0, 3]} - result = resolve_region_io_insertion_points(None, graph, "input") # Should find both consumers @@ -812,8 +569,7 @@ def test_resolve_with_multiple_consumers(self): # relu1_out feeds conv2 (node 2) graph.tensor_users_map = {"relu1_out": [2]} - region = Region(region_id=1, level=0, region_type=RegionType.LEAF) - region.nodes.add(2) # Conv2 + region = _create_region(nodes=[2]) # Conv2 result = resolve_region_io_insertion_points(region, graph, "relu1_out") @@ -938,22 +694,16 @@ def test_no_merge_residual_partial(self): assert ip.node_index == 0 # Still node-specific -class TestNodeInputInsertionPointResolve(unittest.TestCase): - """Test NodeInputInsertionPoint.resolve() method.""" +class TestNodeInputInsertionPointMethods(unittest.TestCase): + """Test 
NodeInputInsertionPoint.resolve() and collect_from_region() methods.""" def test_resolve_simple(self): """Test resolving a simple node input for Conv->BN->Relu->Pool.""" graph, tensors = _create_simple_graph() - - region = Region(region_id=1, level=0, region_type=RegionType.LEAF) - region.nodes.add(0) # Conv node - region.nodes.add(1) # BatchNorm node - region.nodes.add(2) # Relu node - region.nodes.add(3) # MaxPool node + region = _create_region(nodes=[0, 1, 2, 3]) # Conv, BatchNorm, Relu, MaxPool # Create insertion point for first input of first node (Conv) ip = NodeInputInsertionPoint(node_index=0, input_index=0) - result = ip.resolve(region, graph) assert len(result) >= 1 @@ -962,13 +712,10 @@ def test_resolve_simple(self): def test_resolve_conv_includes_weight(self): """Test that resolving Conv input also includes weight.""" graph, tensors = _create_simple_graph() - - region = Region(region_id=1, level=0, region_type=RegionType.LEAF) - region.nodes.add(0) # Conv node + region = _create_region(nodes=[0]) # Conv node # Create insertion point for first input of Conv (should also add weight) ip = NodeInputInsertionPoint(node_index=0, input_index=0) - result = ip.resolve(region, graph) # Should include both data input and weight @@ -980,15 +727,10 @@ def test_resolve_conv_includes_weight(self): def test_resolve_relu_input(self): """Test resolving Relu input in the middle of the chain.""" graph, tensors = _create_simple_graph() - - region = Region(region_id=1, level=0, region_type=RegionType.LEAF) - region.nodes.add(0) # Conv - region.nodes.add(1) # BatchNorm - region.nodes.add(2) # Relu + region = _create_region(nodes=[0, 1, 2]) # Conv, BatchNorm, Relu # Relu is at local index 2, input 0 is bn_out ip = NodeInputInsertionPoint(node_index=2, input_index=0) - result = ip.resolve(region, graph) assert len(result) == 1 @@ -998,15 +740,10 @@ def test_resolve_relu_input(self): def test_resolve_residual_conv_input(self): """Test resolving Conv input in residual block.""" 
graph, tensors = _create_residual_graph() - - region = Region(region_id=1, level=0, region_type=RegionType.LEAF) - region.nodes.add(0) # Conv1 - region.nodes.add(1) # Relu1 - region.nodes.add(2) # Conv2 + region = _create_region(nodes=[0, 1, 2]) # Conv1, Relu1, Conv2 # Conv2 is at local index 2, input 0 is relu1_out ip = NodeInputInsertionPoint(node_index=2, input_index=0) - result = ip.resolve(region, graph) # Conv includes both data and weight @@ -1015,9 +752,34 @@ def test_resolve_residual_conv_input(self): assert "relu1_out" in tensor_names assert "conv2_weight" in tensor_names + def test_collect_valid_inputs(self): + """Test collecting valid node input insertion points from Conv->BN->Relu->Pool.""" + graph, tensors = _create_simple_graph() + region = _create_region(nodes=[0, 1, 2, 3]) # Conv, BatchNorm, Relu, MaxPool + result = NodeInputInsertionPoint.collect_from_region(region, graph) + + # Should have collected some insertion points + assert len(result) >= 1 + # All should be NodeInputInsertionPoint + assert all(isinstance(ip, NodeInputInsertionPoint) for ip in result) + + def test_collect_from_residual_block(self): + """Test collecting from residual block with skip connection.""" + graph, tensors = _create_residual_graph() + region = _create_region(nodes=[0, 1, 2, 3, 4]) # Conv1, Relu1, Conv2, Add, Relu2 + result = NodeInputInsertionPoint.collect_from_region(region, graph) + + # Should have collected insertion points from Conv1, Add inputs, etc. 
+ assert len(result) >= 1 + assert all(isinstance(ip, NodeInputInsertionPoint) for ip in result) + + # Check that we have insertion points for different nodes + node_indices = {ip.node_index for ip in result} + assert len(node_indices) >= 1 # At least one node has valid inputs + -class TestChildRegionInputInsertionPointResolve(unittest.TestCase): - """Test ChildRegionInputInsertionPoint.resolve() method.""" +class TestChildRegionInputInsertionPointMethods(unittest.TestCase): + """Test ChildRegionInputInsertionPoint.resolve() and collect_from_region() methods.""" def test_resolve_composite_region(self): """Test resolving child region input in COMPOSITE region.""" @@ -1025,16 +787,11 @@ def test_resolve_composite_region(self): graph.tensor_users_map = {"input": [0]} # Create parent (COMPOSITE) with child (LEAF) containing Conv->BN->Relu - parent = Region(region_id=1, level=1, region_type=RegionType.COMPOSITE) - child = Region(region_id=2, level=0, region_type=RegionType.LEAF) + parent = _create_region(region_id=1, level=1, region_type=RegionType.COMPOSITE) + child = _create_region(region_id=2, nodes=[0, 1, 2]) # Conv, BatchNorm, Relu child.inputs = ["input"] - child.nodes.add(0) # Conv - child.nodes.add(1) # BatchNorm - child.nodes.add(2) # Relu parent.add_child(child) - ip = ChildRegionInputInsertionPoint(region_index=0, input_index=0) - result = ip.resolve(parent, graph) assert len(result) >= 1 @@ -1043,14 +800,9 @@ def test_resolve_composite_region(self): def test_resolve_leaf_returns_empty(self): """Test that LEAF regions return empty set.""" graph, _ = _create_simple_graph() - - leaf = Region(region_id=1, level=0, region_type=RegionType.LEAF) - leaf.nodes.add(0) - + leaf = _create_region(nodes=[0]) ip = ChildRegionInputInsertionPoint(region_index=0, input_index=0) - result = ip.resolve(leaf, graph) - assert len(result) == 0 def test_resolve_multiple_children(self): @@ -1060,18 +812,15 @@ def test_resolve_multiple_children(self): graph.tensor_users_map = 
get_tensor_consumer_node_indices(graph) # Create parent with two child regions - parent = Region(region_id=1, level=1, region_type=RegionType.COMPOSITE) + parent = _create_region(region_id=1, level=1, region_type=RegionType.COMPOSITE) # First child: Conv1 (consumes "input") - child1 = Region(region_id=2, level=0, region_type=RegionType.LEAF) + child1 = _create_region(region_id=2, nodes=[0]) # Conv1 child1.inputs = ["input"] - child1.nodes.add(0) # Conv1 # Second child: Relu1 (consumes "relu1_out") - child2 = Region(region_id=3, level=0, region_type=RegionType.LEAF) + child2 = _create_region(region_id=3, nodes=[2]) # Relu1 child2.inputs = ["relu1_out"] - child2.nodes.add(2) # Relu1 - parent.add_child(child1) parent.add_child(child2) @@ -1089,27 +838,53 @@ def test_resolve_multiple_children(self): assert len(result2) >= 1 assert any(rip.tensor_name == "relu1_out" for rip in result2) + def test_collect_from_composite(self): + """Test collecting from COMPOSITE region with children.""" + graph, tensors = _create_simple_graph() + parent = _create_region(region_id=1, level=1, region_type=RegionType.COMPOSITE) + child = _create_region(region_id=2, nodes=[0, 1, 2]) # Conv, BatchNorm, Relu + child.inputs = ["input"] + parent.add_child(child) + result = ChildRegionInputInsertionPoint.collect_from_region(parent, graph) + # Should find the child's input + assert len(result) >= 0 # May be filtered by skip_invalid_insertion_points + assert all(isinstance(ip, ChildRegionInputInsertionPoint) for ip in result) + + def test_collect_from_leaf_returns_empty(self): + """Test that LEAF regions return empty list.""" + graph, _ = _create_simple_graph() + leaf = _create_region(nodes=[0]) + result = ChildRegionInputInsertionPoint.collect_from_region(leaf, graph) + assert len(result) == 0 + + def test_collect_from_composite_with_multiple_children(self): + """Test collecting from COMPOSITE with multiple child regions.""" + graph, tensors = _create_residual_graph() + parent = 
_create_region(region_id=1, level=1, region_type=RegionType.COMPOSITE) + child1 = _create_region(region_id=2, nodes=[0, 1]) # Conv1, Relu1 + child1.inputs = ["input"] + child2 = _create_region(region_id=3, nodes=[2, 3]) # Conv2, Add + child2.inputs = ["relu1_out", "input"] # Two inputs including skip connection + parent.add_child(child1) + parent.add_child(child2) + + result = ChildRegionInputInsertionPoint.collect_from_region(parent, graph) + # Should find inputs from both children + assert all(isinstance(ip, ChildRegionInputInsertionPoint) for ip in result) + -class TestChildRegionOutputInsertionPointResolve(unittest.TestCase): - """Test ChildRegionOutputInsertionPoint.resolve() method.""" +class TestChildRegionOutputInsertionPointMethods(unittest.TestCase): + """Test ChildRegionOutputInsertionPoint.resolve() and collect_from_region() methods.""" def test_resolve_node_output(self): """Test resolving a node output.""" graph, tensors = _create_simple_graph() graph.tensor_users_map = get_tensor_consumer_node_indices(graph) - - region = Region(region_id=1, level=0, region_type=RegionType.LEAF) - region.nodes.add(0) # Conv - region.nodes.add(1) # BatchNorm - region.nodes.add(2) # Relu - region.nodes.add(3) # MaxPool + region = _create_region(nodes=[0, 1, 2, 3]) # Conv, BatchNorm, Relu, MaxPool region.outputs = ["pool_out"] - # Output of last node (MaxPool) ip = ChildRegionOutputInsertionPoint(region_index=None, node_index=2, output_index=0) - result = ip.resolve(region, graph) - assert len(result) >= 1 assert any(rip.tensor_name == "relu_out" for rip in result) @@ -1117,19 +892,12 @@ def test_resolve_child_region_output(self): """Test resolving a child region output.""" graph, tensors = _create_simple_graph() graph.tensor_users_map = {"relu_out": [3]} - - parent = Region(region_id=1, level=1, region_type=RegionType.COMPOSITE) - child = Region(region_id=2, level=0, region_type=RegionType.LEAF) + parent = _create_region(region_id=1, level=1, 
region_type=RegionType.COMPOSITE) + child = _create_region(region_id=2, nodes=[0, 1, 2]) # Conv, BatchNorm, Relu child.outputs = ["relu_out"] - child.nodes.add(0) # Conv - child.nodes.add(1) # BatchNorm - child.nodes.add(2) # Relu parent.add_child(child) - ip = ChildRegionOutputInsertionPoint(region_index=0, node_index=None, output_index=0) - result = ip.resolve(parent, graph) - assert len(result) >= 1 assert any(rip.tensor_name == "relu_out" for rip in result) @@ -1137,137 +905,19 @@ def test_resolve_residual_add_output(self): """Test resolving Add output in residual block.""" graph, tensors = _create_residual_graph() graph.tensor_users_map = {"add_out": [4]} - - region = Region(region_id=1, level=0, region_type=RegionType.LEAF) - region.nodes.add(0) # Conv1 - region.nodes.add(1) # Relu1 - region.nodes.add(2) # Conv2 - region.nodes.add(3) # Add - region.nodes.add(4) # Relu2 + region = _create_region(nodes=[0, 1, 2, 3, 4]) # Conv1, Relu1, Conv2, Add, Relu2 region.outputs = ["add_out"] - # Add is at local index 3, output 0 ip = ChildRegionOutputInsertionPoint(region_index=None, node_index=3, output_index=0) - result = ip.resolve(region, graph) - assert len(result) >= 1 assert any(rip.tensor_name == "add_out" for rip in result) - -class TestNodeInputInsertionPointCollectFrom(unittest.TestCase): - """Test NodeInputInsertionPoint.collect_from_region() method.""" - - def test_collect_valid_inputs(self): - """Test collecting valid node input insertion points from Conv->BN->Relu->Pool.""" - graph, tensors = _create_simple_graph() - - region = Region(region_id=1, level=0, region_type=RegionType.LEAF) - region.nodes.add(0) # Conv - region.nodes.add(1) # BatchNorm - region.nodes.add(2) # Relu - region.nodes.add(3) # MaxPool - - result = NodeInputInsertionPoint.collect_from_region(region, graph) - - # Should have collected some insertion points - assert len(result) >= 1 - # All should be NodeInputInsertionPoint - assert all(isinstance(ip, NodeInputInsertionPoint) for ip in 
result) - - def test_collect_from_residual_block(self): - """Test collecting from residual block with skip connection.""" - graph, tensors = _create_residual_graph() - - region = Region(region_id=1, level=0, region_type=RegionType.LEAF) - region.nodes.add(0) # Conv1 - region.nodes.add(1) # Relu1 - region.nodes.add(2) # Conv2 - region.nodes.add(3) # Add - region.nodes.add(4) # Relu2 - - result = NodeInputInsertionPoint.collect_from_region(region, graph) - - # Should have collected insertion points from Conv1, Add inputs, etc. - assert len(result) >= 1 - assert all(isinstance(ip, NodeInputInsertionPoint) for ip in result) - - # Check that we have insertion points for different nodes - node_indices = {ip.node_index for ip in result} - assert len(node_indices) >= 1 # At least one node has valid inputs - - -class TestChildRegionInputInsertionPointCollectFrom(unittest.TestCase): - """Test ChildRegionInputInsertionPoint.collect_from_region() method.""" - - def test_collect_from_composite(self): - """Test collecting from COMPOSITE region with children.""" - graph, tensors = _create_simple_graph() - - parent = Region(region_id=1, level=1, region_type=RegionType.COMPOSITE) - child = Region(region_id=2, level=0, region_type=RegionType.LEAF) - child.inputs = ["input"] - child.nodes.add(0) # Conv - child.nodes.add(1) # BatchNorm - child.nodes.add(2) # Relu - parent.add_child(child) - - result = ChildRegionInputInsertionPoint.collect_from_region(parent, graph) - - # Should find the child's input - assert len(result) >= 0 # May be filtered by skip_invalid_insertion_points - assert all(isinstance(ip, ChildRegionInputInsertionPoint) for ip in result) - - def test_collect_from_leaf_returns_empty(self): - """Test that LEAF regions return empty list.""" - graph, _ = _create_simple_graph() - - leaf = Region(region_id=1, level=0, region_type=RegionType.LEAF) - leaf.nodes.add(0) - - result = ChildRegionInputInsertionPoint.collect_from_region(leaf, graph) - - assert len(result) == 0 - - 
def test_collect_from_composite_with_multiple_children(self): - """Test collecting from COMPOSITE with multiple child regions.""" - graph, tensors = _create_residual_graph() - - parent = Region(region_id=1, level=1, region_type=RegionType.COMPOSITE) - - child1 = Region(region_id=2, level=0, region_type=RegionType.LEAF) - child1.inputs = ["input"] - child1.nodes.add(0) # Conv1 - child1.nodes.add(1) # Relu1 - - child2 = Region(region_id=3, level=0, region_type=RegionType.LEAF) - child2.inputs = ["relu1_out", "input"] # Two inputs including skip connection - child2.nodes.add(2) # Conv2 - child2.nodes.add(3) # Add - - parent.add_child(child1) - parent.add_child(child2) - - result = ChildRegionInputInsertionPoint.collect_from_region(parent, graph) - - # Should find inputs from both children - assert all(isinstance(ip, ChildRegionInputInsertionPoint) for ip in result) - - -class TestChildRegionOutputInsertionPointCollectFrom(unittest.TestCase): - """Test ChildRegionOutputInsertionPoint.collect_from_region() method.""" - def test_collect_node_outputs(self): """Test collecting node output insertion points.""" graph, tensors = _create_simple_graph() - - region = Region(region_id=1, level=0, region_type=RegionType.LEAF) - region.nodes.add(0) # Conv - region.nodes.add(1) # BatchNorm - region.nodes.add(2) # Relu - region.nodes.add(3) # MaxPool + region = _create_region(nodes=[0, 1, 2, 3]) # Conv, BatchNorm, Relu, MaxPool region.outputs = ["pool_out"] # Only pool_out is a region output - result = ChildRegionOutputInsertionPoint.collect_from_region(region, graph) # Should find the node output that matches region output @@ -1277,16 +927,11 @@ def test_collect_node_outputs(self): def test_collect_child_region_outputs(self): """Test collecting child region output insertion points.""" graph, tensors = _create_simple_graph() - - parent = Region(region_id=1, level=1, region_type=RegionType.COMPOSITE) - child = Region(region_id=2, level=0, region_type=RegionType.LEAF) + parent = 
_create_region(region_id=1, level=1, region_type=RegionType.COMPOSITE) + child = _create_region(region_id=2, nodes=[0, 1, 2]) # Conv, BatchNorm, Relu child.outputs = ["relu_out"] - child.nodes.add(0) # Conv - child.nodes.add(1) # BatchNorm - child.nodes.add(2) # Relu parent.add_child(child) parent.outputs = ["relu_out"] # Child output is also parent output - result = ChildRegionOutputInsertionPoint.collect_from_region(parent, graph) # Should find the child region output @@ -1295,15 +940,8 @@ def test_collect_child_region_outputs(self): def test_collect_residual_block_outputs(self): """Test collecting outputs from residual block.""" graph, tensors = _create_residual_graph() - - region = Region(region_id=1, level=0, region_type=RegionType.LEAF) - region.nodes.add(0) # Conv1 - region.nodes.add(1) # Relu1 - region.nodes.add(2) # Conv2 - region.nodes.add(3) # Add - region.nodes.add(4) # Relu2 + region = _create_region(nodes=[0, 1, 2, 3, 4]) # Conv1, Relu1, Conv2, Add, Relu2 region.outputs = ["output"] # Final output - result = ChildRegionOutputInsertionPoint.collect_from_region(region, graph) # Should find the output diff --git a/tests/unit/onnx/quantization/autotune/test_region.py b/tests/unit/onnx/quantization/autotune/test_region.py index 297d5a2be..a27b1c98c 100644 --- a/tests/unit/onnx/quantization/autotune/test_region.py +++ b/tests/unit/onnx/quantization/autotune/test_region.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -13,154 +13,100 @@ # See the License for the specific language governing permissions and # limitations under the License. -""" -Tests for the Region class in the autotuner. - -Tests region creation, hierarchy, and boundary management. 
-""" +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for the Region class in the autotuner.""" -import unittest +import pytest from modelopt.onnx.quantization.autotune.common import Region, RegionType -class TestRegion(unittest.TestCase): - """Test Region class functionality.""" - - def test_region_creation(self): - """Test creating regions of all types.""" - test_cases = [ - {"region_id": 1, "level": 0, "region_type": RegionType.LEAF}, - {"region_id": 2, "level": 1, "region_type": RegionType.COMPOSITE}, - {"region_id": 0, "level": 2, "region_type": RegionType.ROOT}, - ] - - for params in test_cases: - with self.subTest(**params): - region = Region(**params) - assert region.id == params["region_id"] - assert region.level == params["level"] - assert region.type == params["region_type"] - - def test_parent_child_relationship(self): - """Test parent-child relationships.""" - parent = Region(region_id=1, level=1, region_type=RegionType.COMPOSITE) - child1 = Region(region_id=2, level=0, region_type=RegionType.LEAF) - child2 = Region(region_id=3, level=0, region_type=RegionType.LEAF) - - parent.add_child(child1) - parent.add_child(child2) - - assert len(parent.get_children()) == 2 - assert child1.parent == parent - assert child2.parent == parent - assert child1 in parent.get_children() - assert child2 in parent.get_children() - - def test_add_nodes(self): - """Test adding nodes to a region.""" - region = Region(region_id=1, level=0, region_type=RegionType.LEAF) - - region.nodes.add(0) - region.nodes.add(1) - region.nodes.add(2) - - assert len(region.nodes) == 3 - assert 0 in region.get_nodes() - assert 1 in region.get_nodes() - assert 2 in region.get_nodes() - - def 
test_input_output_tensors(self): - """Test setting input and output tensors.""" - region = Region(region_id=1, level=0, region_type=RegionType.LEAF) - - # Directly assign to inputs/outputs attributes - region.inputs = ["input_tensor_1", "input_tensor_2"] - region.outputs = ["output_tensor_1"] - - assert len(region.inputs) == 2 - assert len(region.outputs) == 1 - assert "input_tensor_1" in region.inputs - assert "output_tensor_1" in region.outputs - - def test_region_size_recursive(self): - """Test recursive size calculation.""" - parent = Region(region_id=1, level=1, region_type=RegionType.COMPOSITE) - child1 = Region(region_id=2, level=0, region_type=RegionType.LEAF) - child2 = Region(region_id=3, level=0, region_type=RegionType.LEAF) - - # Add nodes to children - child1.nodes.add(0) - child1.nodes.add(1) - child2.nodes.add(2) - child2.nodes.add(3) - child2.nodes.add(4) - - # Add children to parent - parent.add_child(child1) - parent.add_child(child2) - - # Parent itself might have direct nodes - parent.nodes.add(5) - - # Recursive count should include all nodes - assert len(parent.get_region_nodes_and_descendants()) == 6 - - def test_metadata(self): - """Test region metadata storage.""" - region = Region(region_id=1, level=0, region_type=RegionType.LEAF) - - region.metadata["pattern"] = "Conv->Relu" - region.metadata["quantizable"] = "true" - - assert region.metadata["pattern"] == "Conv->Relu" - assert region.metadata["quantizable"] == "true" - - def test_region_type_checks(self): - """Test checking region types (LEAF and COMPOSITE).""" - leaf = Region(region_id=1, level=0, region_type=RegionType.LEAF) - composite = Region(region_id=2, level=1, region_type=RegionType.COMPOSITE) - - assert leaf.type == RegionType.LEAF - assert leaf.type != RegionType.COMPOSITE - assert composite.type == RegionType.COMPOSITE - assert composite.type != RegionType.LEAF - - def test_hierarchical_structure(self): - """Test complex hierarchical structure.""" - root = Region(region_id=0, 
level=2, region_type=RegionType.ROOT) - composite1 = Region(region_id=1, level=1, region_type=RegionType.COMPOSITE) - composite2 = Region(region_id=2, level=1, region_type=RegionType.COMPOSITE) - leaf1 = Region(region_id=3, level=0, region_type=RegionType.LEAF) - leaf2 = Region(region_id=4, level=0, region_type=RegionType.LEAF) - leaf3 = Region(region_id=5, level=0, region_type=RegionType.LEAF) - - # Build hierarchy - root.add_child(composite1) - root.add_child(composite2) - composite1.add_child(leaf1) - composite1.add_child(leaf2) - composite2.add_child(leaf3) - - # Add some nodes - leaf1.nodes.add(0) - leaf2.nodes.add(1) - leaf3.nodes.add(2) - - # Verify structure - assert len(root.get_children()) == 2 - assert len(composite1.get_children()) == 2 - assert len(composite2.get_children()) == 1 - assert len(root.get_region_nodes_and_descendants()) == 3 - - def test_remove_child(self): - """Test removing a child region.""" - parent = Region(region_id=1, level=1, region_type=RegionType.COMPOSITE) - child = Region(region_id=2, level=0, region_type=RegionType.LEAF) - - parent.add_child(child) - assert len(parent.get_children()) == 1 - - parent.remove_child(child) - assert len(parent.get_children()) == 0 - assert child.parent is None +@pytest.fixture +def leaf(): + return Region(region_id=1, level=0, region_type=RegionType.LEAF) + + +@pytest.fixture +def parent_with_children(): + parent = Region(region_id=1, level=1, region_type=RegionType.COMPOSITE) + child1 = Region(region_id=2, level=0, region_type=RegionType.LEAF) + child2 = Region(region_id=3, level=0, region_type=RegionType.LEAF) + parent.add_child(child1) + parent.add_child(child2) + return parent, child1, child2 + + +@pytest.mark.parametrize( + ("region_id", "level", "region_type"), + [ + (1, 0, RegionType.LEAF), + (2, 1, RegionType.COMPOSITE), + (0, 2, RegionType.ROOT), + ], +) +def test_region_creation(region_id, level, region_type): + region = Region(region_id=region_id, level=level, region_type=region_type) + 
assert (region.id, region.level, region.type) == (region_id, level, region_type) + + +def test_parent_child_relationship(parent_with_children): + parent, child1, child2 = parent_with_children + assert parent.get_children() == [child1, child2] + assert child1.parent == child2.parent == parent + + +def test_add_and_get_nodes(leaf): + leaf.nodes.update([0, 1, 2]) + assert set(leaf.get_nodes()) == {0, 1, 2} + + +def test_input_output_tensors(leaf): + leaf.inputs = ["in1", "in2"] + leaf.outputs = ["out1"] + assert leaf.inputs == ["in1", "in2"] + assert leaf.outputs == ["out1"] + + +def test_region_size_recursive(parent_with_children): + parent, child1, child2 = parent_with_children + child1.nodes.update([0, 1]) + child2.nodes.update([2, 3, 4]) + parent.nodes.add(5) + assert len(parent.get_region_nodes_and_descendants()) == 6 + + +def test_metadata(leaf): + leaf.metadata.update({"pattern": "Conv->Relu", "quantizable": "true"}) + assert leaf.metadata == {"pattern": "Conv->Relu", "quantizable": "true"} + + +def test_hierarchical_structure(): + root = Region(region_id=0, level=2, region_type=RegionType.ROOT) + comp1 = Region(region_id=1, level=1, region_type=RegionType.COMPOSITE) + comp2 = Region(region_id=2, level=1, region_type=RegionType.COMPOSITE) + leaves = [Region(region_id=i, level=0, region_type=RegionType.LEAF) for i in range(3, 6)] + root.add_child(comp1) + root.add_child(comp2) + comp1.add_child(leaves[0]) + comp1.add_child(leaves[1]) + comp2.add_child(leaves[2]) + for i, leaf in enumerate(leaves): + leaf.nodes.add(i) + assert len(root.get_children()) == 2 + assert len(comp1.get_children()) == 2 + assert len(comp2.get_children()) == 1 + assert len(root.get_region_nodes_and_descendants()) == 3 + + +def test_remove_child(): + parent = Region(region_id=1, level=1, region_type=RegionType.COMPOSITE) + child = Region(region_id=2, level=0, region_type=RegionType.LEAF) + parent.add_child(child) + parent.remove_child(child) + assert parent.get_children() == [] + assert 
child.parent is None