@@ -722,6 +722,79 @@ def test_calculate_numeric_gap(self):
722722 # gap should equal 3.0
723723 self .assertEqual (row ["gap" ][0 ], 3.0 )
724724
725+ def test_calculate_numeric_gap_with_custom_comparator (self ):
726+ """Test calculate_numeric_gap with a custom NumericalComparatorBase implementation."""
727+ from executorch .devtools .inspector .numerical_comparator import (
728+ NumericalComparatorBase ,
729+ )
730+
731+ # Create a custom comparator that returns the max absolute difference
732+ class MaxAbsDiffComparator (NumericalComparatorBase ):
733+ def compare (self , a , b ):
734+ if isinstance (a , torch .Tensor ) and isinstance (b , torch .Tensor ):
735+ return torch .max (torch .abs (a - b )).item ()
736+ return abs (a - b )
737+
738+ # Create a context manager to patch functions called by Inspector.__init__
739+ with patch .object (
740+ _inspector , "parse_etrecord" , return_value = None
741+ ), patch .object (
742+ _inspector , "gen_etdump_object" , return_value = None
743+ ), patch .object (
744+ EventBlock , "_gen_from_etdump"
745+ ), patch .object (
746+ _inspector , "gen_graphs_from_etrecord"
747+ ):
748+ # Call the constructor of Inspector
749+ inspector_instance = Inspector (
750+ etdump_path = ETDUMP_PATH ,
751+ etrecord = ETRECORD_PATH ,
752+ )
753+
754+ aot_intermediate_outputs = {
755+ (0 ,): torch .tensor ([1.0 , 2.0 , 3.0 ]),
756+ (1 ,): torch .tensor ([4.0 , 5.0 , 6.0 ]),
757+ }
758+
759+ runtime_intermediate_outputs = {
760+ (0 ,): ([torch .tensor ([2.0 , 1.0 , 5.0 ])], 1 ),
761+ (1 ,): ([torch .tensor ([3.0 , 6.0 , 5.0 ])], 1 ),
762+ }
763+
764+ aot_debug_handle_to_op_name = {(0 ,): "op_0" , (1 ,): "op_1" }
765+ runtime_debug_handle_to_op_name = {(0 ,): "op_0" , (1 ,): "op_1" }
766+
767+ inspector_instance ._get_aot_intermediate_outputs_and_op_names = lambda x : (
768+ aot_intermediate_outputs ,
769+ aot_debug_handle_to_op_name ,
770+ )
771+ inspector_instance ._get_runtime_intermediate_outputs_and_op_names = (
772+ lambda : (runtime_intermediate_outputs , runtime_debug_handle_to_op_name )
773+ )
774+
775+ # Create custom comparator instance
776+ custom_comparator = MaxAbsDiffComparator ()
777+
778+ # Test with custom comparator
779+ df = inspector_instance .calculate_numeric_gap (distance = custom_comparator )
780+ self .assertIsInstance (df , pd .DataFrame )
781+ self .assertEqual (len (df ), 2 )
782+ cols = set (df .columns )
783+ expected_cols = {
784+ "aot_ops" ,
785+ "aot_intermediate_output" ,
786+ "runtime_ops" ,
787+ "runtime_intermediate_output" ,
788+ "gap" ,
789+ }
790+ self .assertEqual (cols , expected_cols )
791+
792+ # Verify the custom comparator logic
793+ # For (0,): max(|[1.0, 2.0, 3.0] - [2.0, 1.0, 5.0]|) = max([1.0, 1.0, 2.0]) = 2.0
794+ self .assertEqual (df .iloc [0 ]["gap" ][0 ], 2.0 )
795+ # For (1,): max(|[4.0, 5.0, 6.0] - [3.0, 6.0, 5.0]|) = max([1.0, 1.0, 1.0]) = 1.0
796+ self .assertEqual (df .iloc [1 ]["gap" ][0 ], 1.0 )
797+
725798 @unittest .skip ("ci config values are not propagated" )
726799 def test_intermediate_tensor_comparison_with_torch_export (self ):
727800 """Test intermediate tensor comparison using torch.export.export and to_edge_transform_and_lower."""
0 commit comments