Skip to content

Commit 1864fb0

Browse files
authored
custom comparator support in calculate_numeric_gap
Differential Revision: D87840446 Pull Request resolved: #15969
1 parent ba26536 commit 1864fb0

File tree

4 files changed

+97
-39
lines changed

4 files changed

+97
-39
lines changed

devtools/inspector/_inspector.py

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@
7373
from executorch.devtools.inspector.numerical_comparator import (
7474
L1Comparator,
7575
MSEComparator,
76+
NumericalComparatorBase,
7677
SNRComparator,
7778
)
7879
from executorch.exir import ExportedProgram
@@ -1404,7 +1405,9 @@ def get_exported_program(
14041405
)
14051406

14061407
def calculate_numeric_gap(
1407-
self, distance: str = "MSE", disable_debug_handle_valdiation: bool = False
1408+
self,
1409+
distance: Union[str, NumericalComparatorBase],
1410+
disable_debug_handle_valdiation: bool = False,
14081411
):
14091412
"""
14101413
Compares logged intermediate outputs from the exported graph (in ETRecord)
@@ -1416,7 +1419,10 @@ def calculate_numeric_gap(
14161419
compare the intermediate outputs from the AOT and the runtime.
14171420
14181421
Args:
1419-
distance: The metrics the inspector will use for gap calculation. Should be one of "MSE", "L1" and "SNR".
1422+
distance: The metrics the inspector will use for gap calculation. Can be either:
1423+
- A string: one of "MSE", "L1", or "SNR" for built-in comparators.
1424+
- A custom NumericalComparatorBase instance: allows you to define custom comparison logic
1425+
by subclassing NumericalComparatorBase and implementing the compare() method.
14201426
disable_debug_handle_validation: Often when aten graph has symbolic shape nodes and inbuilt ops like gt/lt etc.,
14211427
during re-export of such a graph 'from_node' information is lost from node.meta. As a result we loose
14221428
connection between edge IR nodes and aten nodes for such ops. By default we validate that every edge IR
@@ -1442,15 +1448,18 @@ def calculate_numeric_gap(
14421448
mapping = map_runtime_aot_intermediate_outputs(
14431449
aot_intermediate_outputs, runtime_intermediate_outputs
14441450
)
1445-
metric = distance.strip().upper()
1446-
if metric == "MSE":
1447-
comparator = MSEComparator()
1448-
elif metric == "L1":
1449-
comparator = L1Comparator()
1450-
elif metric == "SNR":
1451-
comparator = SNRComparator()
1451+
if isinstance(distance, NumericalComparatorBase):
1452+
comparator = distance
14521453
else:
1453-
raise ValueError(f"Unsupported distance metric {distance!r}")
1454+
metric = distance.strip().upper()
1455+
if metric == "MSE":
1456+
comparator = MSEComparator()
1457+
elif metric == "L1":
1458+
comparator = L1Comparator()
1459+
elif metric == "SNR":
1460+
comparator = SNRComparator()
1461+
else:
1462+
raise ValueError(f"Unsupported distance metric {distance!r}")
14541463

14551464
rows = []
14561465
for (aot_debug_handle, aot_intermediate_output), (

devtools/inspector/numerical_comparator/__init__.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,13 @@
1313
MSEComparator,
1414
)
1515

16+
from executorch.devtools.inspector.numerical_comparator.numerical_comparator_base import (
17+
NumericalComparatorBase,
18+
)
19+
1620
from executorch.devtools.inspector.numerical_comparator.snr_numerical_comparator import (
1721
SNRComparator,
1822
)
1923

2024

21-
__all__ = ["L1Comparator", "MSEComparator", "SNRComparator"]
25+
__all__ = ["L1Comparator", "MSEComparator", "SNRComparator", "NumericalComparatorBase"]

devtools/inspector/numerical_comparator/inspector_numerical_comparator_base.py

Lines changed: 0 additions & 28 deletions
This file was deleted.

devtools/inspector/tests/inspector_test.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -722,6 +722,79 @@ def test_calculate_numeric_gap(self):
722722
# gap should equal 3.0
723723
self.assertEqual(row["gap"][0], 3.0)
724724

725+
def test_calculate_numeric_gap_with_custom_comparator(self):
726+
"""Test calculate_numeric_gap with a custom NumericalComparatorBase implementation."""
727+
from executorch.devtools.inspector.numerical_comparator import (
728+
NumericalComparatorBase,
729+
)
730+
731+
# Create a custom comparator that returns the max absolute difference
732+
class MaxAbsDiffComparator(NumericalComparatorBase):
733+
def compare(self, a, b):
734+
if isinstance(a, torch.Tensor) and isinstance(b, torch.Tensor):
735+
return torch.max(torch.abs(a - b)).item()
736+
return abs(a - b)
737+
738+
# Create a context manager to patch functions called by Inspector.__init__
739+
with patch.object(
740+
_inspector, "parse_etrecord", return_value=None
741+
), patch.object(
742+
_inspector, "gen_etdump_object", return_value=None
743+
), patch.object(
744+
EventBlock, "_gen_from_etdump"
745+
), patch.object(
746+
_inspector, "gen_graphs_from_etrecord"
747+
):
748+
# Call the constructor of Inspector
749+
inspector_instance = Inspector(
750+
etdump_path=ETDUMP_PATH,
751+
etrecord=ETRECORD_PATH,
752+
)
753+
754+
aot_intermediate_outputs = {
755+
(0,): torch.tensor([1.0, 2.0, 3.0]),
756+
(1,): torch.tensor([4.0, 5.0, 6.0]),
757+
}
758+
759+
runtime_intermediate_outputs = {
760+
(0,): ([torch.tensor([2.0, 1.0, 5.0])], 1),
761+
(1,): ([torch.tensor([3.0, 6.0, 5.0])], 1),
762+
}
763+
764+
aot_debug_handle_to_op_name = {(0,): "op_0", (1,): "op_1"}
765+
runtime_debug_handle_to_op_name = {(0,): "op_0", (1,): "op_1"}
766+
767+
inspector_instance._get_aot_intermediate_outputs_and_op_names = lambda x: (
768+
aot_intermediate_outputs,
769+
aot_debug_handle_to_op_name,
770+
)
771+
inspector_instance._get_runtime_intermediate_outputs_and_op_names = (
772+
lambda: (runtime_intermediate_outputs, runtime_debug_handle_to_op_name)
773+
)
774+
775+
# Create custom comparator instance
776+
custom_comparator = MaxAbsDiffComparator()
777+
778+
# Test with custom comparator
779+
df = inspector_instance.calculate_numeric_gap(distance=custom_comparator)
780+
self.assertIsInstance(df, pd.DataFrame)
781+
self.assertEqual(len(df), 2)
782+
cols = set(df.columns)
783+
expected_cols = {
784+
"aot_ops",
785+
"aot_intermediate_output",
786+
"runtime_ops",
787+
"runtime_intermediate_output",
788+
"gap",
789+
}
790+
self.assertEqual(cols, expected_cols)
791+
792+
# Verify the custom comparator logic
793+
# For (0,): max(|[1.0, 2.0, 3.0] - [2.0, 1.0, 5.0]|) = max([1.0, 1.0, 2.0]) = 2.0
794+
self.assertEqual(df.iloc[0]["gap"][0], 2.0)
795+
# For (1,): max(|[4.0, 5.0, 6.0] - [3.0, 6.0, 5.0]|) = max([1.0, 1.0, 1.0]) = 1.0
796+
self.assertEqual(df.iloc[1]["gap"][0], 1.0)
797+
725798
@unittest.skip("ci config values are not propagated")
726799
def test_intermediate_tensor_comparison_with_torch_export(self):
727800
"""Test intermediate tensor comparison using torch.export.export and to_edge_transform_and_lower."""

0 commit comments

Comments
 (0)