Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions modelopt/onnx/autocast/graphsanitizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,7 @@ def _match_layernorm_pattern(self, mean_node: onnx.NodeProto) -> dict | None:
# Find and extract scale and bias nodes if present
scale = None
bias = None
scale_dimension = None
final_node = div_node
nodes_to_remove = [
mean_node,
Expand Down
21 changes: 15 additions & 6 deletions modelopt/onnx/autocast/nodeclassifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,11 @@ def _check_inner(self, node):
axis_array = onnx.numpy_helper.to_array(axis_init)
assert axis_array.ndim == 0 or (axis_array.ndim == 1 and axis_array.size == 1)
axis = int(axis_array.item())
# Normalize negative axis and check bounds
if axis < 0:
axis = len(input_0_dims) + axis
if axis < 0 or axis >= len(input_0_dims):
return False
if input_0_dims[axis] > self.max_depth_of_reduction:
self.reduction_depth = input_0_dims[axis]

Expand Down Expand Up @@ -449,18 +454,22 @@ def run(self, ref_outputs_dict=None):
"""
exclude_node_rules = self._gen_exclude_node_rules(ref_outputs_dict)
include_node_rules = self._gen_include_node_rules()
low_precision_nodes = self.custom_ops_low_precision_nodes or []
# Use a set to avoid duplicates
low_precision_nodes_set = set(self.custom_ops_low_precision_nodes or [])
high_precision_nodes = []
for node in self.model.graph.node:
# Skip if already classified as low precision
if node.name in low_precision_nodes_set:
continue
# If any condition is met - node will be executed in high precision
if (
node.name not in low_precision_nodes
and any(rule.check(node) for rule in exclude_node_rules)
and not any(rule.check(node) for rule in include_node_rules)
if any(rule.check(node) for rule in exclude_node_rules) and not any(
rule.check(node) for rule in include_node_rules
):
high_precision_nodes.append(node.name)
else:
low_precision_nodes.append(node.name)
low_precision_nodes_set.add(node.name)
# Convert back to list for return value
low_precision_nodes = list(low_precision_nodes_set)
logger.debug(f"Low Precision Nodes: {low_precision_nodes}")
logger.debug(f"High Precision Nodes: {high_precision_nodes}")
return low_precision_nodes, high_precision_nodes