diff --git a/modelopt/onnx/autocast/graphsanitizer.py b/modelopt/onnx/autocast/graphsanitizer.py index 85f407a59..97c987f36 100644 --- a/modelopt/onnx/autocast/graphsanitizer.py +++ b/modelopt/onnx/autocast/graphsanitizer.py @@ -331,6 +331,7 @@ def _match_layernorm_pattern(self, mean_node: onnx.NodeProto) -> dict | None: # Find and extract scale and bias nodes if present scale = None bias = None + scale_dimension = None final_node = div_node nodes_to_remove = [ mean_node, diff --git a/modelopt/onnx/autocast/nodeclassifier.py b/modelopt/onnx/autocast/nodeclassifier.py index 0a7638429..828f522e2 100644 --- a/modelopt/onnx/autocast/nodeclassifier.py +++ b/modelopt/onnx/autocast/nodeclassifier.py @@ -286,6 +286,11 @@ def _check_inner(self, node): axis_array = onnx.numpy_helper.to_array(axis_init) assert axis_array.ndim == 0 or (axis_array.ndim == 1 and axis_array.size == 1) axis = int(axis_array.item()) + # Normalize negative axis and check bounds + if axis < 0: + axis = len(input_0_dims) + axis + if axis < 0 or axis >= len(input_0_dims): + return False if input_0_dims[axis] > self.max_depth_of_reduction: self.reduction_depth = input_0_dims[axis] @@ -449,18 +454,22 @@ def run(self, ref_outputs_dict=None): """ exclude_node_rules = self._gen_exclude_node_rules(ref_outputs_dict) include_node_rules = self._gen_include_node_rules() - low_precision_nodes = self.custom_ops_low_precision_nodes or [] + # Use a set to avoid duplicates + low_precision_nodes_set = set(self.custom_ops_low_precision_nodes or []) high_precision_nodes = [] for node in self.model.graph.node: + # Skip if already classified as low precision + if node.name in low_precision_nodes_set: + continue # If any condition is met - node will be executed in high precision - if ( - node.name not in low_precision_nodes - and any(rule.check(node) for rule in exclude_node_rules) - and not any(rule.check(node) for rule in include_node_rules) + if any(rule.check(node) for rule in exclude_node_rules) and not any( + rule.check(node) for rule in include_node_rules ): high_precision_nodes.append(node.name) else: - low_precision_nodes.append(node.name) + low_precision_nodes_set.add(node.name) + # Convert back to list for return value + low_precision_nodes = list(low_precision_nodes_set) logger.debug(f"Low Precision Nodes: {low_precision_nodes}") logger.debug(f"High Precision Nodes: {high_precision_nodes}") return low_precision_nodes, high_precision_nodes