diff --git a/modelopt/onnx/op_types.py b/modelopt/onnx/op_types.py index cc94a221f..30c14e90e 100644 --- a/modelopt/onnx/op_types.py +++ b/modelopt/onnx/op_types.py @@ -96,7 +96,7 @@ def is_fusible_scaling_op(op_type: str): ] -def get_copy_ops(): +def get_copy_ops() -> list[str]: """Returns list of copy operators.""" return [ "Flatten", @@ -303,3 +303,67 @@ def is_data_dependent_shape_op(op_type: str): "NonZero", "RoiAlign", ] + + +def get_bool_ops(): + """Returns set of bool operations.""" + return { + "Not", + "And", + "Or", + "Xor", + } + + +def get_bitwise_ops(): + """Returns set of bitwise operations.""" + return { + "BitwiseAnd", + "BitwiseOr", + "BitwiseXor", + "BitShift", + } + + +def get_value_check_ops(): + """Returns set of value checking operations.""" + return { + "IsNaN", + "IsInf", + "Sign", + "Abs", + } + + +def get_comparison_ops(): + """Returns set of comparison operations.""" + return { + "Equal", + "Greater", + "GreaterOrEqual", + "Less", + "LessOrEqual", + } + + +def get_conditional_ops(): + """Returns set of conditional operations.""" + return { + "Where", + } + + +def get_aggregation_ops(): + """Returns set of aggregation operations.""" + return { + "All", + "Any", + } + + +def get_set_ops(): + """Returns set of set/search operations.""" + return { + "Unique", + "NonZero", + } diff --git a/modelopt/onnx/quantization/autotune/common.py b/modelopt/onnx/quantization/autotune/common.py new file mode 100644 index 000000000..a8929315a --- /dev/null +++ b/modelopt/onnx/quantization/autotune/common.py @@ -0,0 +1,317 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Common data structures and types for the QDQ Autotuner.""" + +import hashlib +from dataclasses import dataclass, field +from enum import Enum +from typing import Any + +from modelopt.onnx.logging_config import logger +from modelopt.onnx.quantization.autotune.insertion_points import ( + ChildRegionInputInsertionPoint, + ChildRegionOutputInsertionPoint, + NodeInputInsertionPoint, +) + + +class AutotunerError(Exception): + """Base exception for autotuner-related errors.""" + + +class AutotunerNotInitializedError(AutotunerError): + """Exception raised when autotuner is used without initialization.""" + + +class InvalidSchemeError(AutotunerError): + """Exception raised when an invalid scheme is referenced.""" + + +class RegionType(Enum): + """Region type enumeration for hierarchical graph structure. + + - LEAF: Atomic region containing direct nodes with no child regions + - COMPOSITE: Hierarchical region containing child regions (and optionally direct nodes) + - ROOT: Top-level region encompassing the entire computation graph + """ + + LEAF = "LEAF" + COMPOSITE = "COMPOSITE" + ROOT = "ROOT" + + +class Region: + """A subgraph region in an ONNX graph, used as the unit for Q/DQ insertion. + + Regions form a hierarchy: ROOT contains the entire graph, COMPOSITE regions + contain child regions, and LEAF regions contain only nodes. Each region tracks + its direct nodes, input/output tensors, and a pattern signature for matching + regions with identical structure. + """ + + def __init__(self, region_id: int, level: int, region_type: RegionType): + """Initialize a new region. + + Args: + region_id: Unique identifier within the region hierarchy + level: Hierarchical level (0 = leaf, higher = more composite) + region_type: Type classification (LEAF, COMPOSITE, or ROOT) + """ + self.id = region_id + self.level = level + self.type = region_type + self.parent: Region | None = None + self.children: list[Region] = [] + self.nodes: set[int] = set() + self.inputs: list[str] = [] + self.outputs: list[str] = [] + self.metadata: dict[str, str] = {} + + def get_children(self, *, sort: bool = False) -> list["Region"]: + """Get all child regions. If sort is True, sort the children by level and size. + + Args: + sort: Whether to sort the children by level and size + + Returns: + List of child regions + """ + if sort: + return sorted( + self.children, key=lambda r: (-r.level, r.get_size_of_region_and_descendants()) + ) + return self.children + + def remove_child(self, child: "Region") -> bool: + """Remove a child region from this region's children list.""" + if child not in self.children: + return False + self.children.remove(child) + if child.parent and child.parent.id == self.id: + child.parent = None + return True + + def add_child(self, child: "Region") -> None: + """Add a child sub-region.""" + if child.id == self.id: + logger.warning(f"Cannot add region {self.id} as its own child") + return + + if self.is_descendant_of(child): + logger.warning( + f"Cycle detected: region {self.id} is already a descendant of region {child.id}" + ) + return + + if child.parent is not None and child.parent.id != self.id: + old_parent_id = child.parent.id + logger.debug( + f"Re-parenting region {child.id}: moving from parent {old_parent_id} to {self.id}" + ) + child.parent.remove_child(child) + + if any(c.id == child.id for c in self.children): + logger.debug(f"Region {child.id} already child of {self.id}") + return + + self.children.append(child) + child.parent = self + + def is_descendant_of(self, potential_ancestor: "Region") -> bool: + """Check if this region is a descendant of potential_ancestor.""" + visited = set() + current = self.parent + while current: + if current.id in visited: + return False + visited.add(current.id) + if current.id == potential_ancestor.id: + return True + current = current.parent + return False + + def get_nodes(self, *, sort: bool = False) -> list[int]: + """Get direct node indices in this region only.""" + if sort: + return sorted(self.nodes) + return list(self.nodes) + + def get_region_nodes_and_descendants(self, _visited: set[int] | None = None) -> set[int]: + """Get all node indices recursively, including descendants.""" + if _visited is None: + _visited = set() + + # Detect cycles + assert self.id not in _visited, f"Cycle detected in region {self.id} during node traversal" + + _visited.add(self.id) + all_nodes = set(self.nodes) + for child in self.children: + all_nodes.update(child.get_region_nodes_and_descendants(_visited)) + return all_nodes + + def contains_node(self, node_index: int) -> bool: + """Check if region contains a specific node (direct only).""" + return node_index in self.nodes + + def contains_node_within_region_and_descendants(self, node_index: int) -> bool: + """Check if region contains a node recursively.""" + return node_index in self.get_region_nodes_and_descendants() + + def get_size_of_region_and_descendants(self, _visited: set[int] | None = None) -> int: + """Get total node count recursively including all descendants.""" + if _visited is None: + _visited = set() + + # Detect cycles + assert self.id not in _visited, ( + f"Cycle detected in region {self.id} during size calculation" + ) + + _visited.add(self.id) + total = len(self.nodes) + for child in self.children: + total += child.get_size_of_region_and_descendants(_visited) + return total + + def merge(self, other: "Region") -> None: + """Merge another region into this one.""" + if not other: + return + self.nodes.update(other.nodes) + for child in other.children: + self.add_child(child) + + def __repr__(self) -> str: + type_str = self.type.value + return ( + f"Region[id={self.id}, level={self.level}, type={type_str}, " + f"nodes={len(self.nodes)}, children={len(self.children)}, " + f"inputs={len(self.inputs)}, outputs={len(self.outputs)}]" + ) + + +@dataclass +class InsertionScheme: + """Complete Q/DQ insertion specification for a region pattern. + + An InsertionScheme defines a complete Q/DQ configuration for a pattern, + combining both node-level and region-level insertion points. The scheme + is applied to all regions matching the pattern. + """ + + node_inputs: list[NodeInputInsertionPoint] = field(default_factory=list) + child_region_inputs: list[ChildRegionInputInsertionPoint] = field(default_factory=list) + region_outputs: list[ChildRegionOutputInsertionPoint] = field(default_factory=list) + latency_ms: float = float("inf") + error: bool = False + profile_timestamp: str | None = None + + @property + def hash(self) -> str: + """Compute deterministic hash for scheme identity. + + The hash uniquely identifies this scheme configuration based on its + insertion points. Two schemes with identical insertion points produce + the same hash, regardless of their measured latencies. + """ + sorted_nodes = sorted([(pt.node_index, pt.input_index) for pt in self.node_inputs]) + sorted_regions = sorted( + [(pt.region_index, pt.input_index) for pt in self.child_region_inputs] + ) + sorted_region_outputs = sorted( + [(pt.region_index, pt.node_index, pt.output_index) for pt in self.region_outputs] + ) + + hash_input = f"{sorted_nodes}|{sorted_regions}|{sorted_region_outputs}" + + return hashlib.sha256(hash_input.encode("utf-8")).hexdigest()[:32] + + @property + def is_empty(self) -> bool: + """Check if this is a baseline scheme with no Q/DQ insertions.""" + return not self.node_inputs and not self.child_region_inputs and not self.region_outputs + + @property + def is_profiled(self) -> bool: + """Check if this scheme has been profiled (measured). + + A scheme is considered profiled if it has been measured (has non-infinite latency) + or has encountered an error during measurement. + """ + return self.error or self.latency_ms != float("inf") + + def to_dict(self) -> dict[str, Any]: + """Convert to dictionary for serialization.""" + return { + "latency_ms": self.latency_ms, + "error": self.error, + "profile_timestamp": self.profile_timestamp, + "nodes_insertion_points": [pt.to_dict() for pt in self.node_inputs], + "child_region_inputs": [pt.to_dict() for pt in self.child_region_inputs], + "region_outputs": [pt.to_dict() for pt in self.region_outputs], + "hash": self.hash, + } + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "InsertionScheme": + """Create InsertionScheme from serialized dictionary.""" + scheme = cls() + scheme.latency_ms = data.get("latency_ms", float("inf")) + scheme.error = data.get("error", False) + scheme.profile_timestamp = data.get("profile_timestamp") + + scheme.node_inputs = [ + NodeInputInsertionPoint.from_dict(pt) for pt in data.get("nodes_insertion_points", []) + ] + scheme.child_region_inputs = [ + ChildRegionInputInsertionPoint.from_dict(pt) + for pt in data.get("child_region_inputs", []) + ] + scheme.region_outputs = [ + ChildRegionOutputInsertionPoint.from_dict(pt) for pt in data.get("region_outputs", []) + ] + + return scheme + + def distance(self, other: "InsertionScheme") -> int: + """Compute edit distance between this scheme and another scheme. + + The edit distance is the minimum number of add/remove operations needed + to transform this scheme into the other scheme. This is computed as the + symmetric difference between the insertion point sets. + + Args: + other: InsertionScheme to compare against + + Returns: + Total edit distance (number of add + remove operations) + """ + return ( + len(set(self.node_inputs).symmetric_difference(other.node_inputs)) + + len(set(self.child_region_inputs).symmetric_difference(other.child_region_inputs)) + + len(set(self.region_outputs).symmetric_difference(other.region_outputs)) + ) + + def __str__(self) -> str: + """String representation for debugging.""" + error_str = ", error=True" if self.error else "" + return ( + f"InsertionScheme(node_insertions={len(self.node_inputs)}, " + f"region_insertions={len(self.child_region_inputs)}, " + f"region_output_insertions={len(self.region_outputs)}, " + f"latency={self.latency_ms:.3f}ms{error_str})" + ) diff --git a/modelopt/onnx/quantization/autotune/insertion_points.py b/modelopt/onnx/quantization/autotune/insertion_points.py new file mode 100644 index 000000000..dd01848dd --- /dev/null +++ b/modelopt/onnx/quantization/autotune/insertion_points.py @@ -0,0 +1,531 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Q/DQ insertion point management for ONNX quantization autotune. + +This module provides data structures and utilities for managing Quantization/Dequantization (Q/DQ) +insertion points in ONNX computational graphs during autotune optimization. It enables pattern-based +Q/DQ insertion that can be reused across multiple matching regions in a model. +""" + +from abc import ABC, abstractmethod +from dataclasses import asdict, dataclass +from typing import TYPE_CHECKING, Any + +import numpy as np +import onnx_graphsurgeon as gs + +if TYPE_CHECKING: + from modelopt.onnx.quantization.autotune.common import Region + +from modelopt.onnx.op_types import ( + get_aggregation_ops, + get_bitwise_ops, + get_bool_ops, + get_comparison_ops, + get_conditional_ops, + get_copy_ops, + get_set_ops, + get_value_check_ops, + is_fusible_reduction_op, +) +from modelopt.onnx.quantization.graph_utils import get_tensor_consumer_node_indices + + +class InsertionPoint(ABC): + """Abstract base class for pattern-relative Q/DQ insertion points.""" + + @abstractmethod + def to_dict(self) -> dict[str, Any]: + """Convert to dictionary for serialization.""" + ... + + @classmethod + @abstractmethod + def from_dict(cls, data: dict[str, Any]) -> "InsertionPoint": + """Create from dictionary.""" + ... + + @abstractmethod + def resolve(self, region: "Region", graph: gs.Graph) -> set["ResolvedInsertionPoint"]: + """Resolve pattern-relative insertion point to actual tensor names.""" + ... + + @staticmethod + @abstractmethod + def collect_from_region(region: "Region", graph: gs.Graph) -> list["InsertionPoint"]: + """Collect all valid insertion points of this type from a region.""" + ... + + +@dataclass(frozen=True) +class ResolvedInsertionPoint: + """Resolved Q/DQ insertion point with actual tensor name and optional node context. + + After resolving pattern-relative insertion points, this class represents the + actual location where Q/DQ pairs should be inserted in the graph. It contains the + tensor name and the node index (if applicable) and input index (if applicable). + + This class is immutable (frozen) to allow safe use in sets and as dict keys. + """ + + tensor_name: str + node_index: int | None = None # Absolute graph node index (or None for tensor-level insertion) + input_index: int | None = None # Input tensor index of that node (or None) + + def to_dict(self) -> dict[str, Any]: + """Convert to dictionary for serialization.""" + return asdict(self) + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "ResolvedInsertionPoint": + """Create from dictionary.""" + return cls(**data) + + +@dataclass(frozen=True) +class NodeInputInsertionPoint(InsertionPoint): + """Pattern-relative Q/DQ insertion point at a node's input (frozen/hashable). + + Specifies where to insert a Q/DQ pair within a region pattern using + pattern-relative indices rather than absolute node IDs. This enables + insertion scheme reuse across all regions matching the same pattern. + + This class is immutable (frozen) to allow safe use in sets and as dict keys. + """ + + node_index: int # Pattern-relative node index + input_index: int # Input tensor index of that node + + def to_dict(self) -> dict[str, Any]: + """Convert to dictionary for serialization.""" + return asdict(self) + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "NodeInputInsertionPoint": + """Create from dictionary.""" + return cls(node_index=data["node_index"], input_index=data["input_index"]) + + def resolve(self, region: "Region", graph: gs.Graph) -> set[ResolvedInsertionPoint]: + """Resolve a node input insertion point to actual tensor names for a matching region.""" + node_indices = region.get_nodes(sort=True) + assert self.node_index < len(node_indices), "Node index out of range" + actual_node_idx = node_indices[self.node_index] + node = graph.nodes[actual_node_idx] + assert self.input_index < len(node.inputs), "Input index out of range" + + resolved_ips = set() + # Determine which input indices to resolve (include weights for Conv/ConvTranspose) + input_indices = [self.input_index] + if node.op in ["Conv", "ConvTranspose"]: + assert self.input_index == 0, ( + "Conv/ConvTranspose inputs and weights must be quantized together" + ) + assert len(node.inputs) >= 2, "Conv/ConvTranspose should have at least 2 inputs" + input_indices.append(1) + + for idx in input_indices: + inp = node.inputs[idx] + if hasattr(inp, "name") and inp.name: + resolved_ips.add( + ResolvedInsertionPoint( + tensor_name=inp.name, node_index=actual_node_idx, input_index=idx + ) + ) + return resolved_ips + + @staticmethod + def collect_from_region(region: "Region", graph: gs.Graph) -> list["NodeInputInsertionPoint"]: + """Collect all valid node input insertion points from a region.""" + node_indices = region.get_nodes(sort=True) + insertion_points = [] + for local_idx, node_idx in enumerate(node_indices): + node = graph.nodes[node_idx] + for input_idx, inp in enumerate(node.inputs): + name = getattr(inp, "name", None) + if not name or skip_invalid_insertion_points(graph, name, node): + continue + insertion_points.append( + NodeInputInsertionPoint(node_index=local_idx, input_index=input_idx) + ) + return insertion_points + + +@dataclass(frozen=True) +class ChildRegionInputInsertionPoint(InsertionPoint): + """Pattern-relative Q/DQ insertion point at a child region's input boundary (frozen/hashable). + + Specifies where to insert Q/DQ pairs at the input boundaries of child regions + within COMPOSITE regions. This allows parent regions to control quantization + at child boundaries, potentially overriding or complementing child region + optimizations. + + Only applies to COMPOSITE regions; LEAF regions have no children. + + This class is immutable (frozen) to allow safe use in sets and as dict keys. + """ + + # Pattern-relative child region index + region_index: int + # Input tensor index of that child region + input_index: int + + def to_dict(self) -> dict[str, Any]: + """Convert to dictionary for serialization.""" + return asdict(self) + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "ChildRegionInputInsertionPoint": + """Create from dictionary.""" + return cls(**data) + + def resolve(self, region: "Region", graph: gs.Graph) -> set[ResolvedInsertionPoint]: + """Resolve a child region input insertion point to actual tensor names.""" + from modelopt.onnx.quantization.autotune.common import RegionType + + if region.type == RegionType.LEAF: + return set() + + children_regions = region.get_children(sort=True) + assert self.region_index < len(children_regions), "Child region index out of range" + child_region = children_regions[self.region_index] + assert self.input_index < len(child_region.inputs), "Input index out of range" + tensor_name = child_region.inputs[self.input_index] + return resolve_region_io_insertion_points(child_region, graph, tensor_name) + + @staticmethod + def collect_from_region( + region: "Region", graph: gs.Graph + ) -> list["ChildRegionInputInsertionPoint"]: + """Collect all valid child region input insertion points from a region.""" + from modelopt.onnx.quantization.autotune.common import RegionType + + if region.type == RegionType.LEAF: + return [] + + insertion_points = [] + for local_idx, child_region in enumerate(region.get_children(sort=True)): + for input_idx, inp in enumerate(child_region.inputs): + if skip_invalid_insertion_points(graph, inp, child_region): + continue + insertion_points.append( + ChildRegionInputInsertionPoint(region_index=local_idx, input_index=input_idx) + ) + return insertion_points + + +@dataclass(frozen=True) +class ChildRegionOutputInsertionPoint(InsertionPoint): + """Pattern-relative Q/DQ insertion point at a child region or node output (frozen/hashable). + + Specifies where to insert Q/DQ pairs at output boundaries. This can be either: + 1. Output from a child region (in COMPOSITE regions) + 2. Output from a node within the region + + This class is immutable (frozen) to allow safe use in sets and as dict keys. + """ + + region_index: int | None # Pattern-relative child region index (or None) + node_index: int | None # Pattern-relative node index (or None) + output_index: int # Output tensor index + + def to_dict(self) -> dict[str, Any]: + """Convert to dictionary for serialization.""" + return asdict(self) + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "ChildRegionOutputInsertionPoint": + """Create from dictionary.""" + return cls(**data) + + def resolve(self, region: "Region", graph: gs.Graph) -> set[ResolvedInsertionPoint]: + """Resolve a region output insertion point to actual tensor names.""" + if self.region_index is not None: + children_regions = region.get_children(sort=True) + assert self.region_index < len(children_regions), "Region index out of range" + child_region = children_regions[self.region_index] + assert self.output_index < len(child_region.outputs), "Output index out of range" + tensor_name = child_region.outputs[self.output_index] + return resolve_region_io_insertion_points(child_region, graph, tensor_name) + + if self.node_index is not None: + node_indices = region.get_nodes(sort=True) + assert self.node_index < len(node_indices), "Node index out of range" + node = graph.nodes[node_indices[self.node_index]] + assert self.output_index < len(node.outputs), "Output index out of range" + tensor = node.outputs[self.output_index] + assert hasattr(tensor, "name") and tensor.name, "Tensor name is required" + return resolve_region_io_insertion_points(None, graph, tensor.name) + + return set() + + @staticmethod + def collect_from_region( + region: "Region", graph: gs.Graph + ) -> list["ChildRegionOutputInsertionPoint"]: + """Collect all valid region output insertion points from a region.""" + from modelopt.onnx.quantization.autotune.common import RegionType + + region_outputs_set = set(region.outputs) + insertion_points = [] + + # For COMPOSITE regions: collect child region outputs + if region.type != RegionType.LEAF: + for local_idx, child_region in enumerate(region.get_children(sort=True)): + for output_idx, out in enumerate(child_region.outputs): + if out in region_outputs_set and not skip_invalid_insertion_points( + graph, out, child_region + ): + insertion_points.append( + ChildRegionOutputInsertionPoint( + region_index=local_idx, node_index=None, output_index=output_idx + ) + ) + + # For all regions: collect node outputs + for local_idx, node_idx in enumerate(region.get_nodes(sort=True)): + node = graph.nodes[node_idx] + for output_idx, out in enumerate(node.outputs): + if not (hasattr(out, "name") and out.name): + continue + if out.name in region_outputs_set and not skip_invalid_insertion_points( + graph, out.name, node + ): + insertion_points.append( + ChildRegionOutputInsertionPoint( + region_index=None, node_index=local_idx, output_index=output_idx + ) + ) + + return insertion_points + + +def skip_invalid_insertion_points( + graph: gs.Graph, tensor_name: str, region_or_node: "Region | gs.Node" +) -> bool: + """Determine if a tensor should be skipped for Q/DQ insertion. + + Filters out tensors that are not suitable for quantization based on various criteria: + - Boolean and shape operations (not quantizable) + - Fused operation patterns (Conv->BatchNorm->ReLU) + - Operation-specific non-quantizable inputs (weights, biases, BN parameters) + - Non-floating-point tensors (indices, masks) + - Small tensors (scalars, small vectors with < 8 elements) + + Args: + graph: The ONNX graph containing the nodes + tensor_name: Name of the tensor to evaluate + region_or_node: Either a Region or a Node to check for usage of this tensor + + Returns: + True if the insertion point should be skipped, False if it's valid for quantization + """ + from modelopt.onnx.quantization.autotune.common import Region + + if isinstance(region_or_node, Region): + node_indices = region_or_node.get_region_nodes_and_descendants() + nodes: list[gs.Node] = [graph.nodes[node_idx] for node_idx in node_indices] + else: + assert isinstance(region_or_node, gs.Node) + nodes = [region_or_node] + + for node in nodes: + for input_idx, inp in enumerate(node.inputs): + if hasattr(inp, "name") and inp.name == tensor_name: + # Skip weights of Conv and ConvTranspose, they should be quantized with inputs at same time + if node.op in ["Conv", "ConvTranspose"] and input_idx >= 1: + return True + # Conv -> ReLU/Softmax or Conv -> BatchNormalization -> ReLU/Softmax + if node.op in ["Relu", "Softmax"]: + if len(node.inputs) == 1 and len(node.inputs[0].inputs) == 1: + producer = node.inputs[0].inputs[0] + if producer.op in ["Conv", "ConvTranspose"]: + return True + if ( + producer.op == "BatchNormalization" + and len(producer.inputs[0].inputs) == 1 + and producer.inputs[0].inputs[0].op in ["Conv", "ConvTranspose"] + ): + return True + # Conv -> BatchNormalization + if node.op == "BatchNormalization": + assert len(node.inputs) >= 1, "BN node should have more than one inputs" + if len(node.inputs[0].inputs) == 1: + producer = node.inputs[0].inputs[0] + if producer.op in ["Conv", "ConvTranspose"]: + return True + # Filter 1: out boolean operations + if node.op in ( + get_bool_ops() + | get_bitwise_ops() + | get_value_check_ops() + | get_comparison_ops() + | get_conditional_ops() + | get_aggregation_ops() + | get_set_ops() + ) or is_fusible_reduction_op(node.op): + return True + # Filter 2: out shape operations + if node.op in get_autotuner_skip_ops(): + return True + # Filter 3: Skip operation-specific non-quantizable inputs + if node.op in ["BatchNormalization", "Resize"] and input_idx >= 1: + return True + if node.op in ["Conv", "Gemm"] and input_idx >= 2: + return True + # Filter 4: Skip non-floating-point tensors (int/bool indices, masks, etc.) + if hasattr(inp, "dtype") and inp.dtype not in [ + None, + np.float32, + np.float16, + np.float64, + ]: + return True + # Filter 5: Skip small tensors (scalars, small vectors) + if hasattr(inp, "shape") and inp.shape is not None: + if all(isinstance(s, int) for s in inp.shape): + if np.prod(inp.shape) < 8: + return True + return False + + +def has_quantizable_operations(region: "Region", graph: gs.Graph) -> bool: + """Check if a region contains major quantizable operations (only checks LEAF regions). + + Args: + region: The region to check + graph: The ONNX graph containing the nodes + + Returns: + True if the region contains major quantizable operations, False otherwise + """ + from modelopt.onnx.quantization.autotune.common import RegionType + + if region.type != RegionType.LEAF: + return True + region_ops = {graph.nodes[idx].op for idx in region.get_nodes()} + return bool(region_ops & get_autotuner_quantizable_ops()) + + +def resolve_region_io_insertion_points( + region: "Region | None", graph: gs.Graph, tensor_name: str +) -> set[ResolvedInsertionPoint]: + """Resolve region input/output boundaries to actual Q/DQ insertion points. + + For a given tensor at a region boundary (input or output), this function + identifies all the actual node inputs where Q/DQ pairs should be inserted. + It considers both nodes within the region (if provided) and all users of + the tensor in the graph. + + Args: + region: The region to search within (or None to search entire graph) + graph: The ONNX graph containing the nodes + tensor_name: Name of the tensor at the region boundary + + Returns: + Set of ResolvedInsertionPoint objects specifying where to insert Q/DQ pairs + """ + tensor_users_map = getattr(graph, "tensor_users_map", None) or get_tensor_consumer_node_indices( + graph + ) + + node_indices: set[int] = set() + if region is not None: + node_indices.update(region.get_region_nodes_and_descendants()) + node_indices.update(tensor_users_map.get(tensor_name, [])) + + resolved = set() + for node_idx in node_indices: + node = graph.nodes[node_idx] + for input_idx, inp in enumerate(node.inputs): + if hasattr(inp, "name") and inp.name == tensor_name: + if not skip_invalid_insertion_points(graph, tensor_name, node): + resolved.add( + ResolvedInsertionPoint( + tensor_name=tensor_name, node_index=node_idx, input_index=input_idx + ) + ) + return resolved + + +def merge_resolved_insertion_points( + graph: gs.Graph, resolved_insertion_points: set[ResolvedInsertionPoint] +) -> set[ResolvedInsertionPoint]: + """Optimize insertion points by merging node-specific insertions into tensor-level insertions. + + When all consumers (users) of a tensor have Q/DQ insertion points, it's more efficient + to insert Q/DQ once at the tensor level rather than at each individual node input. + This reduces the number of Q/DQ nodes in the graph and simplifies the quantization scheme. + + Args: + graph: The ONNX graph containing the nodes + resolved_insertion_points: Set of resolved insertion points to optimize + + Returns: + Optimized set of insertion points with merged tensor-level insertions where possible + """ + tensor_users_map = get_tensor_consumer_node_indices(graph) + node_ips = {ip for ip in resolved_insertion_points if ip.node_index is not None} + + results = resolved_insertion_points - node_ips + for tensor_name in {ip.tensor_name for ip in node_ips}: + all_users = set(tensor_users_map.get(tensor_name, [])) + qdq_users = {ip for ip in node_ips if ip.tensor_name == tensor_name} + if all_users == {ip.node_index for ip in qdq_users}: + results.add( + ResolvedInsertionPoint(tensor_name=tensor_name, node_index=None, input_index=None) + ) + else: + results.update(qdq_users) + return results + + +def get_autotuner_skip_ops(): + """Returns set of shape/structural operations that are not quantizable.""" + return set(get_copy_ops()) | { + # Additional indexing/scatter/reshape ops + "Compress", + "Scatter", + "ExpandDims", + "Unsqueeze", + "View", + "Pad", + # Utility ops + "Cast", + "Ceil", + "Clip", + "Identity", + "Range", + "Shape", + } + + +def get_autotuner_quantizable_ops(): + """Returns set of key operations that benefit from quantization.""" + return { + "Conv", + "ConvTranspose", + "Gemm", + "MatMul", + "AveragePool", + "MaxPool", + "GlobalAveragePool", + "GlobalMaxPool", + "Resize", + "Add", + "Sum", + "Mul", + "Relu", + } diff --git a/modelopt/onnx/quantization/graph_utils.py b/modelopt/onnx/quantization/graph_utils.py index 67596d5df..efa77dd7b 100755 --- a/modelopt/onnx/quantization/graph_utils.py +++ b/modelopt/onnx/quantization/graph_utils.py @@ -302,6 +302,28 @@ def get_tensor_consumer_nodes( return tensor_consumers +def get_tensor_consumer_node_indices(graph: onnx.GraphProto | gs.Graph) -> dict[str, list[int]]: + """Build a mapping from tensor names to the indices of nodes that use them. + + Args: + graph: ONNX GraphSurgeon graph to analyze + Returns: + Dictionary mapping tensor names to lists of node indices that consume them + """ + tensor_consumer_map: dict[str, list[int]] = defaultdict(list) + nodes = graph.nodes if isinstance(graph, gs.Graph) else graph.node + for node_idx, node in enumerate(nodes): + inputs = node.inputs if isinstance(node, gs.Node) else node.input + for tensor in inputs: + tensor_name = tensor + if isinstance(tensor, str): + tensor_name = tensor + elif hasattr(tensor, "name") and isinstance(tensor.name, str): + tensor_name = tensor.name + tensor_consumer_map[tensor_name].append(node_idx) + return tensor_consumer_map + + def filter_quantizable_kgen_heads( cask_fusible_partitions: list[list[Node]], kgen_partitions: list[list[Node]], diff --git a/tests/unit/onnx/quantization/autotune/test_insertion_points.py b/tests/unit/onnx/quantization/autotune/test_insertion_points.py new file mode 100644 index 000000000..2818d3172 --- /dev/null +++ b/tests/unit/onnx/quantization/autotune/test_insertion_points.py @@ -0,0 +1,948 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Comprehensive tests for common data structures in the autotuner. + +Tests: +1. InsertionPoint classes (NodeInputInsertionPoint, ChildRegionOutputInsertionPoint, ChildRegionInputInsertionPoint) +2. InsertionScheme serialization/deserialization +3. InsertionScheme hashing and equality +4. InsertionScheme properties and methods +5. PatternSchemes management +6. Utility functions (skip_invalid_insertion_points, has_quantizable_operations, etc.) +7. Resolve and collect_from methods for all InsertionPoint types +""" + +import unittest +from unittest.mock import MagicMock, patch + +import numpy as np +import onnx_graphsurgeon as gs +import pytest + +from modelopt.onnx.quantization.autotune.common import ( + ChildRegionInputInsertionPoint, + ChildRegionOutputInsertionPoint, + InsertionScheme, + NodeInputInsertionPoint, + Region, + RegionType, +) +from modelopt.onnx.quantization.autotune.insertion_points import ( + ResolvedInsertionPoint, + has_quantizable_operations, + merge_resolved_insertion_points, + resolve_region_io_insertion_points, + skip_invalid_insertion_points, +) +from modelopt.onnx.quantization.graph_utils import get_tensor_consumer_node_indices + +INSERTION_POINT_CASES = [ + pytest.param( + NodeInputInsertionPoint, + {"node_index": 5, "input_index": 2}, + {"node_index": 5, "input_index": 2}, + {"node_index": 5, "input_index": 3}, + "node_index", + ["5", "2"], + id="NodeInputInsertionPoint", + ), + pytest.param( + ChildRegionOutputInsertionPoint, + {"region_index": 2, "node_index": None, "output_index": 1}, + {"region_index": 2, "node_index": None, "output_index": 1}, + {"region_index": None, "node_index": 2, "output_index": 1}, + "region_index", + ["region", "2"], + id="ChildRegionOutputInsertionPoint-region", + ), + pytest.param( + ChildRegionOutputInsertionPoint, + {"region_index": None, "node_index": 5, "output_index": 0}, + {"region_index": None, "node_index": 5, "output_index": 0}, + {"region_index": None, "node_index": 5, "output_index": 1}, + "node_index", + ["node", "5"], + id="ChildRegionOutputInsertionPoint-node", + ), + pytest.param( + ChildRegionInputInsertionPoint, + {"region_index": 3, "input_index": 1}, + {"region_index": 3, "input_index": 1}, + {"region_index": 3, "input_index": 2}, + "region_index", + ["3", "1"], + id="ChildRegionInputInsertionPoint", + ), +] + + +class TestInsertionPoints: + """Combined tests for all InsertionPoint types.""" + + @pytest.mark.parametrize(("cls", "kwargs", "_", "__", "___", "____"), INSERTION_POINT_CASES) + def test_creation(self, cls, kwargs, _, __, ___, ____): + point = cls(**kwargs) + for key, val in kwargs.items(): + assert getattr(point, key) == val + + @pytest.mark.parametrize( + ("cls", "kwargs", "_", "__", "mutate_attr", "___"), INSERTION_POINT_CASES + ) + def test_immutability(self, cls, kwargs, _, __, mutate_attr, ___): + point = cls(**kwargs) + with pytest.raises(AttributeError): + setattr(point, mutate_attr, 999) + + @pytest.mark.parametrize( + ("cls", "kwargs", "equal_kwargs", "diff_kwargs", "_", "__"), INSERTION_POINT_CASES + ) + def test_equality(self, cls, kwargs, equal_kwargs, diff_kwargs, _, __): + point1 = cls(**kwargs) + point2 = cls(**equal_kwargs) + point3 = cls(**diff_kwargs) + assert point1 == point2 + assert point1 != point3 + + @pytest.mark.parametrize( + ("cls", "kwargs", "equal_kwargs", "diff_kwargs", "_", "__"), INSERTION_POINT_CASES + ) + def test_hashable(self, cls, kwargs, equal_kwargs, diff_kwargs, _, __): + point1 = cls(**kwargs) + point2 = cls(**equal_kwargs) + point3 = cls(**diff_kwargs) + point_set = {point1, point2, point3} + assert len(point_set) == 2 + + @pytest.mark.parametrize(("cls", "kwargs", "_", "__", "___", "____"), INSERTION_POINT_CASES) + def test_serialization(self, cls, kwargs, _, __, ___, ____): + point = cls(**kwargs) + data = point.to_dict() + for key, val in kwargs.items(): + assert data[key] == val + restored = cls.from_dict(data) + assert point == restored + + @pytest.mark.parametrize( + ("cls", "kwargs", "_", "__", "___", "str_checks"), INSERTION_POINT_CASES + ) + def test_string_representation(self, cls, kwargs, _, __, ___, str_checks): + point = cls(**kwargs) + s = str(point).lower() + for check in str_checks: + assert check.lower() in s + + +class TestInsertionScheme: + """Test InsertionScheme functionality.""" + + def test_empty_scheme(self): + """Test empty InsertionScheme.""" + scheme = InsertionScheme() + assert scheme.is_empty + assert len(scheme.node_inputs) == 0 + assert len(scheme.child_region_inputs) == 0 + assert len(scheme.region_outputs) == 0 + assert not scheme.error + + @pytest.mark.parametrize( + ("attr", "points"), + [ + ("node_inputs", [NodeInputInsertionPoint(0, 0), NodeInputInsertionPoint(1, 0)]), + ( + "region_outputs", + [ + ChildRegionOutputInsertionPoint(None, 0, 0), + ChildRegionOutputInsertionPoint(1, None, 0), + ], + ), + ( + "child_region_inputs", + [ChildRegionInputInsertionPoint(0, 0), ChildRegionInputInsertionPoint(1, 0)], + ), + ], + ) + def test_scheme_with_points_not_empty(self, attr, points): + """Test scheme with insertion points is not empty.""" + scheme = InsertionScheme() + setattr(scheme, attr, points) + assert not scheme.is_empty + assert len(getattr(scheme, attr)) == 2 + + def test_scheme_hash_empty(self): + """Test hash of empty schemes are equal.""" + assert InsertionScheme().hash == InsertionScheme().hash + + def test_scheme_hash_equality(self): + """Test hash with same/different insertion points.""" + + def make_scheme(*node_indices): + s = InsertionScheme() + s.node_inputs = [NodeInputInsertionPoint(i, 0) for i in node_indices] + return s + + assert make_scheme(0, 1).hash == make_scheme(0, 1).hash + assert make_scheme(0, 1).hash == make_scheme(1, 0).hash # order independent + assert make_scheme(0, 1).hash != make_scheme(0, 2).hash + + @pytest.mark.parametrize( + ("error", "latency"), + [ + (False, float("inf")), # empty + (False, 12.5), # full + (True, float("inf")), # with error + ], + ) + def test_serialization_roundtrip(self, error, latency): + """Test serialization roundtrip.""" + scheme = InsertionScheme() + scheme.error = error + scheme.latency_ms = latency + + if latency != float("inf") or error: # add points for non-empty cases + scheme.node_inputs = [NodeInputInsertionPoint(0, 0)] + scheme.child_region_inputs = [ChildRegionInputInsertionPoint(0, 0)] + scheme.region_outputs = [ChildRegionOutputInsertionPoint(None, 0, 0)] + + restored = InsertionScheme.from_dict(scheme.to_dict()) + + assert restored.error == error + assert restored.latency_ms == latency + if not scheme.is_empty: + assert len(restored.node_inputs) == len(scheme.node_inputs) + assert len(restored.child_region_inputs) == len(scheme.child_region_inputs) + assert len(restored.region_outputs) == len(scheme.region_outputs) + + +def _create_mock_tensor(name: str, dtype=np.float32, shape=None): + """Create a mock tensor with the specified properties.""" + tensor = MagicMock() + tensor.name = name + tensor.dtype = dtype + tensor.shape = shape if shape is not None else [1, 3, 224, 224] + tensor.inputs = [] + return tensor + + +def _create_mock_node(op: str, inputs: list, outputs: list, name: str = ""): + """Create a mock node with the specified properties.""" + node = MagicMock(spec=gs.Node) + node.op = op + node.name = name + node.inputs = inputs + node.outputs = outputs + return node + + +def _create_region(region_id=1, level=0, region_type=RegionType.LEAF, nodes=None): + """Create a region with the specified properties. + + Args: + region_id: ID for the region + level: Hierarchy level (0 for LEAF, 1+ for COMPOSITE/ROOT) + region_type: Type of region (LEAF, COMPOSITE, or ROOT) + nodes: Optional list/set of node indices to add to the region + + Returns: + Region with specified properties and nodes + """ + region = Region(region_id=region_id, level=level, region_type=region_type) + if nodes: + region.nodes.update(nodes) + return region + + +def _create_simple_graph(): + """Create a mock graph with Conv -> BatchNorm -> Relu -> MaxPool pattern. + + Graph structure: + input -> Conv -> conv_out -> BatchNorm -> bn_out -> Relu -> relu_out -> MaxPool -> pool_out + """ + # Create tensors with realistic shapes + input_tensor = _create_mock_tensor("input", np.float32, [1, 3, 224, 224]) + weight_tensor = _create_mock_tensor("conv_weight", np.float32, [64, 3, 3, 3]) + bias_tensor = _create_mock_tensor("conv_bias", np.float32, [64]) + conv_output = _create_mock_tensor("conv_out", np.float32, [1, 64, 222, 222]) + + # BatchNorm parameters + bn_scale = _create_mock_tensor("bn_scale", np.float32, [64]) + bn_bias = _create_mock_tensor("bn_bias", np.float32, [64]) + bn_mean = _create_mock_tensor("bn_mean", np.float32, [64]) + bn_var = _create_mock_tensor("bn_var", np.float32, [64]) + bn_output = _create_mock_tensor("bn_out", np.float32, [1, 64, 222, 222]) + + relu_output = _create_mock_tensor("relu_out", np.float32, [1, 64, 222, 222]) + pool_output = _create_mock_tensor("pool_out", np.float32, [1, 64, 111, 111]) + + # Create nodes + conv_node = _create_mock_node( + "Conv", [input_tensor, weight_tensor, bias_tensor], [conv_output], "conv1" + ) + bn_node = _create_mock_node( + "BatchNormalization", + [conv_output, bn_scale, bn_bias, bn_mean, bn_var], + [bn_output], + "bn1", + ) + relu_node = _create_mock_node("Relu", [bn_output], [relu_output], "relu1") + pool_node = _create_mock_node("MaxPool", [relu_output], [pool_output], "pool1") + + # Link tensors to their producer nodes + conv_output.inputs = [conv_node] + bn_output.inputs = [bn_node] + relu_output.inputs = [relu_node] + pool_output.inputs = [pool_node] + input_tensor.inputs = [] + weight_tensor.inputs = [] + bias_tensor.inputs = [] + + # Create graph + graph = MagicMock(spec=gs.Graph) + graph.nodes = [conv_node, bn_node, relu_node, pool_node] + graph.inputs = [input_tensor] + graph.outputs = [pool_output] + + tensors = { + "input": input_tensor, + "conv_weight": weight_tensor, + "conv_bias": bias_tensor, + "conv_out": conv_output, + "bn_out": bn_output, + "relu_out": relu_output, + "pool_out": pool_output, + } + + return graph, tensors + + +def _create_residual_graph(): + """Create a mock graph with a residual block pattern (skip connection). + + Graph structure: + input ─────────────────────────────┐ + │ │ + ▼ │ + Conv1 -> conv1_out │ + │ │ + ▼ │ + Relu1 -> relu1_out │ + │ │ + ▼ │ + Conv2 -> conv2_out │ + │ │ + ▼ ▼ + Add (conv2_out + input) -> add_out + │ + ▼ + Relu2 -> output + """ + # Create tensors + input_tensor = _create_mock_tensor("input", np.float32, [1, 64, 56, 56]) + + # First conv branch + weight1 = _create_mock_tensor("conv1_weight", np.float32, [64, 64, 3, 3]) + conv1_out = _create_mock_tensor("conv1_out", np.float32, [1, 64, 56, 56]) + relu1_out = _create_mock_tensor("relu1_out", np.float32, [1, 64, 56, 56]) + + # Second conv + weight2 = _create_mock_tensor("conv2_weight", np.float32, [64, 64, 3, 3]) + conv2_out = _create_mock_tensor("conv2_out", np.float32, [1, 64, 56, 56]) + + # Add and final relu + add_out = _create_mock_tensor("add_out", np.float32, [1, 64, 56, 56]) + output = _create_mock_tensor("output", np.float32, [1, 64, 56, 56]) + + # Create nodes + conv1_node = _create_mock_node("Conv", [input_tensor, weight1], [conv1_out], "conv1") + relu1_node = _create_mock_node("Relu", [conv1_out], [relu1_out], "relu1") + conv2_node = _create_mock_node("Conv", [relu1_out, weight2], [conv2_out], "conv2") + add_node = _create_mock_node("Add", [conv2_out, input_tensor], [add_out], "add1") + relu2_node = _create_mock_node("Relu", [add_out], [output], "relu2") + + # Link tensors to their producer nodes + conv1_out.inputs = [conv1_node] + relu1_out.inputs = [relu1_node] + conv2_out.inputs = [conv2_node] + add_out.inputs = [add_node] + output.inputs = [relu2_node] + input_tensor.inputs = [] + weight1.inputs = [] + weight2.inputs = [] + + # Create graph + graph = MagicMock(spec=gs.Graph) + graph.nodes = [conv1_node, relu1_node, conv2_node, add_node, relu2_node] + graph.inputs = [input_tensor] + graph.outputs = [output] + + tensors = { + "input": input_tensor, + "conv1_weight": weight1, + "conv1_out": conv1_out, + "relu1_out": relu1_out, + "conv2_weight": weight2, + "conv2_out": conv2_out, + "add_out": add_out, + "output": output, + } + + return graph, tensors + + +class TestSkipInvalidInsertionPoints: + """Test skip_invalid_insertion_points function.""" + + @pytest.mark.parametrize( + ("op", "should_skip"), + [ + ("Equal", True), # bool op + ("Shape", True), # shape op + ("MatMul", False), # normal op + ("Add", False), # normal op + ], + ) + def test_skip_by_op_type(self, op, should_skip): + graph, _ = _create_simple_graph() + tensor = _create_mock_tensor("test_input", np.float32, [1, 64, 32, 32]) + node = _create_mock_node(op, [tensor], []) + assert skip_invalid_insertion_points(graph, "test_input", node) is should_skip + + @pytest.mark.parametrize( + ("dtype", "shape", "should_skip"), + [ + (np.int32, [1, 64, 32, 32], True), # non-float + (np.float32, [1], True), # small tensor + (np.float32, [1, 64, 32, 32], False), # large float - OK + ], + ) + def test_skip_by_tensor_properties(self, dtype, shape, should_skip): + graph, _ = _create_simple_graph() + tensor = _create_mock_tensor("test", dtype, shape) + node = _create_mock_node("Add", [tensor], []) + assert skip_invalid_insertion_points(graph, "test", node) is should_skip + + def test_skip_conv_weight_input(self): + """Conv weight inputs (index >= 1) are skipped.""" + graph, _ = _create_simple_graph() + result = skip_invalid_insertion_points(graph, "conv_weight", graph.nodes[0]) + assert result is True + + def test_skip_bn_non_data_inputs(self): + """BatchNormalization non-data inputs are skipped.""" + graph, _ = _create_simple_graph() + result = skip_invalid_insertion_points(graph, "bn_scale", graph.nodes[1]) + assert result is True + + def test_skip_conv_bn_relu_fusion(self): + """Conv->BN->Relu fusion patterns are skipped at intermediate points.""" + graph, _ = _create_simple_graph() + result = skip_invalid_insertion_points(graph, "bn_out", graph.nodes[2]) + assert result is True + + def test_with_region(self): + """Test with a Region containing multiple nodes.""" + graph, _ = _create_simple_graph() + region = _create_region(nodes=[0, 1]) + + shape_tensor = _create_mock_tensor("shape_input", np.float32) + shape_node = _create_mock_node("Shape", [shape_tensor], []) + graph.nodes.append(shape_node) + region.nodes.add(4) + + assert skip_invalid_insertion_points(graph, "shape_input", region) is True + + def test_residual_block_add_inputs_allowed(self): + """Add node inputs in residual blocks should be allowed.""" + graph, _ = _create_residual_graph() + add_node = graph.nodes[3] + + assert skip_invalid_insertion_points(graph, "conv2_out", add_node) is False + assert skip_invalid_insertion_points(graph, "input", add_node) is False + + +class TestHasQuantizableOperations: + """Test has_quantizable_operations function.""" + + @pytest.mark.parametrize( + ("nodes", "graph_fn", "expected"), + [ + ({0}, _create_simple_graph, True), # Conv + ({3}, _create_simple_graph, True), # MaxPool + ({2}, _create_simple_graph, True), # Relu + ({0, 1, 2}, _create_simple_graph, True), # Conv->BN->Relu + ({3}, _create_residual_graph, True), # Add in residual + ], + ) + def test_leaf_with_quantizable_ops(self, nodes, graph_fn, expected): + """Test LEAF region with various quantizable operations.""" + graph, _ = graph_fn() + region = _create_region(nodes=nodes) + assert has_quantizable_operations(region, graph) is expected + + def test_leaf_without_quantizable_ops(self): + """Test LEAF region without major quantizable operations.""" + shape_tensor = _create_mock_tensor("input", np.float32) + output_tensor = _create_mock_tensor("output", np.float32) + shape_node = _create_mock_node("Shape", [shape_tensor], [output_tensor]) + transpose_node = _create_mock_node("Transpose", [output_tensor], []) + graph = MagicMock(spec=gs.Graph) + graph.nodes = [shape_node, transpose_node] + region = _create_region(nodes={0, 1}) + + assert has_quantizable_operations(region, graph) is False + + def test_composite_region_always_true(self): + """Test that COMPOSITE regions always return True.""" + graph, _ = _create_simple_graph() + region = _create_region(level=1, region_type=RegionType.COMPOSITE) + assert has_quantizable_operations(region, graph) is True + + +class TestResolveRegionIOInsertionPoints(unittest.TestCase): + """Test resolve_region_io_insertion_points function.""" + + def test_resolve_with_region(self): + """Test resolving with a region containing Conv->BN->Relu.""" + graph, tensors = _create_simple_graph() + + # Set up tensor_users_map: conv_out is consumed by BatchNorm (node 1) + graph.tensor_users_map = get_tensor_consumer_node_indices(graph) + region = _create_region(nodes=[2]) # Relu node + result = resolve_region_io_insertion_points(region, graph, "relu_out") + + assert len(result) >= 1 + assert any(ip.tensor_name == "relu_out" for ip in result) + + def test_resolve_without_region(self): + """Test resolving without a region (None) for tensor-level insertion.""" + graph, _ = _create_simple_graph() + + # Set up tensor_users_map: bn_out is consumed by Relu (node 2) + graph.tensor_users_map = get_tensor_consumer_node_indices(graph) + result = resolve_region_io_insertion_points(None, graph, "relu_out") + + assert len(result) == 1 + ip = next(iter(result)) + assert ip.tensor_name == "relu_out" + assert ip.node_index == 3 + assert ip.input_index == 0 + + def test_resolve_tensor_not_found(self): + """Test resolving a tensor that has no users.""" + graph, _ = _create_simple_graph() + graph.tensor_users_map = {} + result = resolve_region_io_insertion_points(None, graph, "nonexistent") + + assert len(result) == 0 + + def test_resolve_residual_skip_connection(self): + """Test resolving input tensor used by both Conv1 and Add (skip connection).""" + graph, tensors = _create_residual_graph() + + # Input tensor is used by Conv1 (node 0) and Add (node 3) + graph.tensor_users_map = {"input": [0, 3]} + result = resolve_region_io_insertion_points(None, graph, "input") + + # Should find both consumers + assert len(result) == 2 + node_indices = {ip.node_index for ip in result} + assert 0 in node_indices # Conv1 + assert 3 in node_indices # Add + + def test_resolve_with_multiple_consumers(self): + """Test resolving tensor with multiple consumers in a region.""" + graph, tensors = _create_residual_graph() + + # relu1_out feeds conv2 (node 2) + graph.tensor_users_map = {"relu1_out": [2]} + + region = _create_region(nodes=[2]) # Conv2 + + result = resolve_region_io_insertion_points(region, graph, "relu1_out") + + assert len(result) == 1 + ip = next(iter(result)) + assert ip.tensor_name == "relu1_out" + assert ip.node_index == 2 + + +class TestMergeResolvedInsertionPoints(unittest.TestCase): + """Test merge_resolved_insertion_points function.""" + + def test_merge_all_users(self): + """Test merging when all users have insertion points.""" + graph, _ = _create_simple_graph() + + # Setup: tensor "conv_out" is used by BatchNorm (node 1) + resolved = { + ResolvedInsertionPoint(tensor_name="conv_out", node_index=1, input_index=0), + } + + with patch( + "modelopt.onnx.quantization.autotune.insertion_points.get_tensor_consumer_node_indices" + ) as mock_get: + mock_get.return_value = {"conv_out": [1]} + + result = merge_resolved_insertion_points(graph, resolved) + + # Should be merged to tensor-level insertion + assert len(result) == 1 + merged = next(iter(result)) + assert merged.tensor_name == "conv_out" + assert merged.node_index is None + assert merged.input_index is None + + def test_no_merge_partial_users(self): + """Test no merging when only some users have insertion points.""" + graph, _ = _create_simple_graph() + + # Setup: tensor "conv_out" is used by nodes 1 and 2, but only node 1 has IP + resolved = { + ResolvedInsertionPoint(tensor_name="conv_out", node_index=1, input_index=0), + } + + with patch( + "modelopt.onnx.quantization.autotune.insertion_points.get_tensor_consumer_node_indices" + ) as mock_get: + mock_get.return_value = {"conv_out": [1, 2]} + + result = merge_resolved_insertion_points(graph, resolved) + + # Should NOT be merged - keep node-specific + assert len(result) == 1 + ip = next(iter(result)) + assert ip.node_index == 1 # Still node-specific + + def test_preserve_tensor_level_insertions(self): + """Test that existing tensor-level insertions are preserved.""" + graph, _ = _create_simple_graph() + + # Already tensor-level insertion + resolved = { + ResolvedInsertionPoint(tensor_name="input", node_index=None, input_index=None), + } + + with patch( + "modelopt.onnx.quantization.autotune.insertion_points.get_tensor_consumer_node_indices" + ) as mock_get: + mock_get.return_value = {"conv_out": [1]} + + result = merge_resolved_insertion_points(graph, resolved) + + assert len(result) == 1 + ip = next(iter(result)) + assert ip.tensor_name == "input" + assert ip.node_index is None + + def test_merge_residual_skip_connection(self): + """Test merging with residual block where input has two users.""" + graph, _ = _create_residual_graph() + + # Input tensor used by Conv1 (node 0) and Add (node 3) + # If we have insertion points for both, they should merge + resolved = { + ResolvedInsertionPoint(tensor_name="input", node_index=0, input_index=0), + ResolvedInsertionPoint(tensor_name="input", node_index=3, input_index=1), + } + + with patch( + "modelopt.onnx.quantization.autotune.insertion_points.get_tensor_consumer_node_indices" + ) as mock_get: + mock_get.return_value = {"input": [0, 3]} + + result = merge_resolved_insertion_points(graph, resolved) + + # Should be merged to tensor-level insertion + assert len(result) == 1 + merged = next(iter(result)) + assert merged.tensor_name == "input" + assert merged.node_index is None + + def test_no_merge_residual_partial(self): + """Test no merging in residual block when only one branch has insertion point.""" + graph, _ = _create_residual_graph() + + # Input tensor used by Conv1 (node 0) and Add (node 3) + # Only Conv1 has an insertion point + resolved = { + ResolvedInsertionPoint(tensor_name="input", node_index=0, input_index=0), + } + + with patch( + "modelopt.onnx.quantization.autotune.insertion_points.get_tensor_consumer_node_indices" + ) as mock_get: + mock_get.return_value = {"input": [0, 3]} + + result = merge_resolved_insertion_points(graph, resolved) + + # Should NOT merge - only one of two users has IP + assert len(result) == 1 + ip = next(iter(result)) + assert ip.node_index == 0 # Still node-specific + + +class TestNodeInputInsertionPointMethods(unittest.TestCase): + """Test NodeInputInsertionPoint.resolve() and collect_from_region() methods.""" + + def test_resolve_simple(self): + """Test resolving a simple node input for Conv->BN->Relu->Pool.""" + graph, tensors = _create_simple_graph() + region = _create_region(nodes=[0, 1, 2, 3]) # Conv, BatchNorm, Relu, MaxPool + + # Create insertion point for first input of first node (Conv) + ip = NodeInputInsertionPoint(node_index=0, input_index=0) + result = ip.resolve(region, graph) + + assert len(result) >= 1 + assert any(rip.tensor_name == "input" for rip in result) + + def test_resolve_conv_includes_weight(self): + """Test that resolving Conv input also includes weight.""" + graph, tensors = _create_simple_graph() + region = _create_region(nodes=[0]) # Conv node + + # Create insertion point for first input of Conv (should also add weight) + ip = NodeInputInsertionPoint(node_index=0, input_index=0) + result = ip.resolve(region, graph) + + # Should include both data input and weight + assert len(result) == 2 + tensor_names = {rip.tensor_name for rip in result} + assert "input" in tensor_names + assert "conv_weight" in tensor_names + + def test_resolve_relu_input(self): + """Test resolving Relu input in the middle of the chain.""" + graph, tensors = _create_simple_graph() + region = _create_region(nodes=[0, 1, 2]) # Conv, BatchNorm, Relu + + # Relu is at local index 2, input 0 is bn_out + ip = NodeInputInsertionPoint(node_index=2, input_index=0) + result = ip.resolve(region, graph) + + assert len(result) == 1 + rip = next(iter(result)) + assert rip.tensor_name == "bn_out" + + def test_resolve_residual_conv_input(self): + """Test resolving Conv input in residual block.""" + graph, tensors = _create_residual_graph() + region = _create_region(nodes=[0, 1, 2]) # Conv1, Relu1, Conv2 + + # Conv2 is at local index 2, input 0 is relu1_out + ip = NodeInputInsertionPoint(node_index=2, input_index=0) + result = ip.resolve(region, graph) + + # Conv includes both data and weight + assert len(result) == 2 + tensor_names = {rip.tensor_name for rip in result} + assert "relu1_out" in tensor_names + assert "conv2_weight" in tensor_names + + def test_collect_valid_inputs(self): + """Test collecting valid node input insertion points from Conv->BN->Relu->Pool.""" + graph, tensors = _create_simple_graph() + region = _create_region(nodes=[0, 1, 2, 3]) # Conv, BatchNorm, Relu, MaxPool + result = NodeInputInsertionPoint.collect_from_region(region, graph) + + # Should have collected some insertion points + assert len(result) >= 1 + # All should be NodeInputInsertionPoint + assert all(isinstance(ip, NodeInputInsertionPoint) for ip in result) + + def test_collect_from_residual_block(self): + """Test collecting from residual block with skip connection.""" + graph, tensors = _create_residual_graph() + region = _create_region(nodes=[0, 1, 2, 3, 4]) # Conv1, Relu1, Conv2, Add, Relu2 + result = NodeInputInsertionPoint.collect_from_region(region, graph) + + # Should have collected insertion points from Conv1, Add inputs, etc. + assert len(result) >= 1 + assert all(isinstance(ip, NodeInputInsertionPoint) for ip in result) + + # Check that we have insertion points for different nodes + node_indices = {ip.node_index for ip in result} + assert len(node_indices) >= 1 # At least one node has valid inputs + + +class TestChildRegionInputInsertionPointMethods(unittest.TestCase): + """Test ChildRegionInputInsertionPoint.resolve() and collect_from_region() methods.""" + + def test_resolve_composite_region(self): + """Test resolving child region input in COMPOSITE region.""" + graph, tensors = _create_simple_graph() + graph.tensor_users_map = {"input": [0]} + + # Create parent (COMPOSITE) with child (LEAF) containing Conv->BN->Relu + parent = _create_region(region_id=1, level=1, region_type=RegionType.COMPOSITE) + child = _create_region(region_id=2, nodes=[0, 1, 2]) # Conv, BatchNorm, Relu + child.inputs = ["input"] + parent.add_child(child) + ip = ChildRegionInputInsertionPoint(region_index=0, input_index=0) + result = ip.resolve(parent, graph) + + assert len(result) >= 1 + assert any(rip.tensor_name == "input" for rip in result) + + def test_resolve_leaf_returns_empty(self): + """Test that LEAF regions return empty set.""" + graph, _ = _create_simple_graph() + leaf = _create_region(nodes=[0]) + ip = ChildRegionInputInsertionPoint(region_index=0, input_index=0) + result = ip.resolve(leaf, graph) + assert len(result) == 0 + + def test_resolve_multiple_children(self): + """Test resolving child inputs in COMPOSITE with multiple children.""" + graph, tensors = _create_residual_graph() + # input is consumed by Conv1 (node 0) and Add (node 3) + graph.tensor_users_map = get_tensor_consumer_node_indices(graph) + + # Create parent with two child regions + parent = _create_region(region_id=1, level=1, region_type=RegionType.COMPOSITE) + + # First child: Conv1 (consumes "input") + child1 = _create_region(region_id=2, nodes=[0]) # Conv1 + child1.inputs = ["input"] + + # Second child: Relu1 (consumes "relu1_out") + child2 = _create_region(region_id=3, nodes=[2]) # Relu1 + child2.inputs = ["relu1_out"] + parent.add_child(child1) + parent.add_child(child2) + + # Resolve input of first child (region_index=0) - "input" tensor + ip1 = ChildRegionInputInsertionPoint(region_index=0, input_index=0) + result1 = ip1.resolve(parent, graph) + + assert len(result1) >= 1 + assert any(rip.tensor_name == "input" for rip in result1) + + # Resolve input of second child (region_index=1) - "relu1_out" tensor + ip2 = ChildRegionInputInsertionPoint(region_index=1, input_index=0) + result2 = ip2.resolve(parent, graph) + + assert len(result2) >= 1 + assert any(rip.tensor_name == "relu1_out" for rip in result2) + + def test_collect_from_composite(self): + """Test collecting from COMPOSITE region with children.""" + graph, tensors = _create_simple_graph() + parent = _create_region(region_id=1, level=1, region_type=RegionType.COMPOSITE) + child = _create_region(region_id=2, nodes=[0, 1, 2]) # Conv, BatchNorm, Relu + child.inputs = ["input"] + parent.add_child(child) + result = ChildRegionInputInsertionPoint.collect_from_region(parent, graph) + # Should find the child's input + assert len(result) >= 0 # May be filtered by skip_invalid_insertion_points + assert all(isinstance(ip, ChildRegionInputInsertionPoint) for ip in result) + + def test_collect_from_leaf_returns_empty(self): + """Test that LEAF regions return empty list.""" + graph, _ = _create_simple_graph() + leaf = _create_region(nodes=[0]) + result = ChildRegionInputInsertionPoint.collect_from_region(leaf, graph) + assert len(result) == 0 + + def test_collect_from_composite_with_multiple_children(self): + """Test collecting from COMPOSITE with multiple child regions.""" + graph, tensors = _create_residual_graph() + parent = _create_region(region_id=1, level=1, region_type=RegionType.COMPOSITE) + child1 = _create_region(region_id=2, nodes=[0, 1]) # Conv1, Relu1 + child1.inputs = ["input"] + child2 = _create_region(region_id=3, nodes=[2, 3]) # Conv2, Add + child2.inputs = ["relu1_out", "input"] # Two inputs including skip connection + parent.add_child(child1) + parent.add_child(child2) + + result = ChildRegionInputInsertionPoint.collect_from_region(parent, graph) + # Should find inputs from both children + assert all(isinstance(ip, ChildRegionInputInsertionPoint) for ip in result) + + +class TestChildRegionOutputInsertionPointMethods(unittest.TestCase): + """Test ChildRegionOutputInsertionPoint.resolve() and collect_from_region() methods.""" + + def test_resolve_node_output(self): + """Test resolving a node output.""" + graph, tensors = _create_simple_graph() + graph.tensor_users_map = get_tensor_consumer_node_indices(graph) + region = _create_region(nodes=[0, 1, 2, 3]) # Conv, BatchNorm, Relu, MaxPool + region.outputs = ["pool_out"] + # Output of last node (MaxPool) + ip = ChildRegionOutputInsertionPoint(region_index=None, node_index=2, output_index=0) + result = ip.resolve(region, graph) + assert len(result) >= 1 + assert any(rip.tensor_name == "relu_out" for rip in result) + + def test_resolve_child_region_output(self): + """Test resolving a child region output.""" + graph, tensors = _create_simple_graph() + graph.tensor_users_map = {"relu_out": [3]} + parent = _create_region(region_id=1, level=1, region_type=RegionType.COMPOSITE) + child = _create_region(region_id=2, nodes=[0, 1, 2]) # Conv, BatchNorm, Relu + child.outputs = ["relu_out"] + parent.add_child(child) + ip = ChildRegionOutputInsertionPoint(region_index=0, node_index=None, output_index=0) + result = ip.resolve(parent, graph) + assert len(result) >= 1 + assert any(rip.tensor_name == "relu_out" for rip in result) + + def test_resolve_residual_add_output(self): + """Test resolving Add output in residual block.""" + graph, tensors = _create_residual_graph() + graph.tensor_users_map = {"add_out": [4]} + region = _create_region(nodes=[0, 1, 2, 3, 4]) # Conv1, Relu1, Conv2, Add, Relu2 + region.outputs = ["add_out"] + # Add is at local index 3, output 0 + ip = ChildRegionOutputInsertionPoint(region_index=None, node_index=3, output_index=0) + result = ip.resolve(region, graph) + assert len(result) >= 1 + assert any(rip.tensor_name == "add_out" for rip in result) + + def test_collect_node_outputs(self): + """Test collecting node output insertion points.""" + graph, tensors = _create_simple_graph() + region = _create_region(nodes=[0, 1, 2, 3]) # Conv, BatchNorm, Relu, MaxPool + region.outputs = ["pool_out"] # Only pool_out is a region output + result = ChildRegionOutputInsertionPoint.collect_from_region(region, graph) + + # Should find the node output that matches region output + assert len(result) >= 0 # May be filtered + assert all(isinstance(ip, ChildRegionOutputInsertionPoint) for ip in result) + + def test_collect_child_region_outputs(self): + """Test collecting child region output insertion points.""" + graph, tensors = _create_simple_graph() + parent = _create_region(region_id=1, level=1, region_type=RegionType.COMPOSITE) + child = _create_region(region_id=2, nodes=[0, 1, 2]) # Conv, BatchNorm, Relu + child.outputs = ["relu_out"] + parent.add_child(child) + parent.outputs = ["relu_out"] # Child output is also parent output + result = ChildRegionOutputInsertionPoint.collect_from_region(parent, graph) + + # Should find the child region output + assert all(isinstance(ip, ChildRegionOutputInsertionPoint) for ip in result) + + def test_collect_residual_block_outputs(self): + """Test collecting outputs from residual block.""" + graph, tensors = _create_residual_graph() + region = _create_region(nodes=[0, 1, 2, 3, 4]) # Conv1, Relu1, Conv2, Add, Relu2 + region.outputs = ["output"] # Final output + result = ChildRegionOutputInsertionPoint.collect_from_region(region, graph) + + # Should find the output + assert all(isinstance(ip, ChildRegionOutputInsertionPoint) for ip in result) diff --git a/tests/unit/onnx/quantization/autotune/test_region.py b/tests/unit/onnx/quantization/autotune/test_region.py new file mode 100644 index 000000000..a27b1c98c --- /dev/null +++ b/tests/unit/onnx/quantization/autotune/test_region.py @@ -0,0 +1,112 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for the Region class in the autotuner.""" + +import pytest + +from modelopt.onnx.quantization.autotune.common import Region, RegionType + + +@pytest.fixture +def leaf(): + return Region(region_id=1, level=0, region_type=RegionType.LEAF) + + +@pytest.fixture +def parent_with_children(): + parent = Region(region_id=1, level=1, region_type=RegionType.COMPOSITE) + child1 = Region(region_id=2, level=0, region_type=RegionType.LEAF) + child2 = Region(region_id=3, level=0, region_type=RegionType.LEAF) + parent.add_child(child1) + parent.add_child(child2) + return parent, child1, child2 + + +@pytest.mark.parametrize( + ("region_id", "level", "region_type"), + [ + (1, 0, RegionType.LEAF), + (2, 1, RegionType.COMPOSITE), + (0, 2, RegionType.ROOT), + ], +) +def test_region_creation(region_id, level, region_type): + region = Region(region_id=region_id, level=level, region_type=region_type) + assert (region.id, region.level, region.type) == (region_id, level, region_type) + + +def test_parent_child_relationship(parent_with_children): + parent, child1, child2 = parent_with_children + assert parent.get_children() == [child1, child2] + assert child1.parent == child2.parent == parent + + +def test_add_and_get_nodes(leaf): + leaf.nodes.update([0, 1, 2]) + assert set(leaf.get_nodes()) == {0, 1, 2} + + +def test_input_output_tensors(leaf): + leaf.inputs = ["in1", "in2"] + leaf.outputs = ["out1"] + assert leaf.inputs == ["in1", "in2"] + assert leaf.outputs == ["out1"] + + +def test_region_size_recursive(parent_with_children): + parent, child1, child2 = parent_with_children + child1.nodes.update([0, 1]) + child2.nodes.update([2, 3, 4]) + parent.nodes.add(5) + assert len(parent.get_region_nodes_and_descendants()) == 6 + + +def test_metadata(leaf): + leaf.metadata.update({"pattern": "Conv->Relu", "quantizable": "true"}) + assert leaf.metadata == {"pattern": "Conv->Relu", "quantizable": "true"} + + +def test_hierarchical_structure(): + root = Region(region_id=0, level=2, region_type=RegionType.ROOT) + comp1 = Region(region_id=1, level=1, region_type=RegionType.COMPOSITE) + comp2 = Region(region_id=2, level=1, region_type=RegionType.COMPOSITE) + leaves = [Region(region_id=i, level=0, region_type=RegionType.LEAF) for i in range(3, 6)] + root.add_child(comp1) + root.add_child(comp2) + comp1.add_child(leaves[0]) + comp1.add_child(leaves[1]) + comp2.add_child(leaves[2]) + for i, leaf in enumerate(leaves): + leaf.nodes.add(i) + assert len(root.get_children()) == 2 + assert len(comp1.get_children()) == 2 + assert len(comp2.get_children()) == 1 + assert len(root.get_region_nodes_and_descendants()) == 3 + + +def test_remove_child(): + parent = Region(region_id=1, level=1, region_type=RegionType.COMPOSITE) + child = Region(region_id=2, level=0, region_type=RegionType.LEAF) + parent.add_child(child) + parent.remove_child(child) + assert parent.get_children() == [] + assert child.parent is None