From 9a6f47a19cc624311500fcd755e5dad30b073a50 Mon Sep 17 00:00:00 2001 From: Krishn Parasar Date: Sat, 11 Oct 2025 03:18:40 +0530 Subject: [PATCH 1/3] testing subclassing --- test/test_utils.py | 43 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 43 insertions(+) diff --git a/test/test_utils.py b/test/test_utils.py index 0e77388f13..9b1c8c87d9 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -342,6 +342,49 @@ def fake_linear(func, types, args, kwargs): counter["calls"], 2, "Expected fake_linear to be called via aten.t.default" ) + def test_subclassing(self): + class Parent(TorchAOBaseTensor): + tensor_data_names = ["qdata"] + tensor_attribute_names = ["attr"] + + Parent._ATEN_OP_TABLE[Parent]["op_parent"] = "parent_impl" + Parent._TORCH_FN_TABLE[Parent]["fn_parent"] = "parent_fn_impl" + + class Child(Parent): + tensor_data_names = ["qdata"] + tensor_attribute_names = ["attr"] + + # ensure child has copied parent ops + self.assertEqual(Child._ATEN_OP_TABLE[Child]["op_parent"], "parent_impl") + self.assertEqual(Child._TORCH_FN_TABLE[Child]["fn_parent"], "parent_fn_impl") + + # ensure the top-level dicts are distinct (not inherited) + self.assertIsNot(Parent._ATEN_OP_TABLE, Child._ATEN_OP_TABLE) + self.assertIsNot(Parent._TORCH_FN_TABLE, Child._TORCH_FN_TABLE) + + # change the parent's op after subclass creation — should not leak + Parent._ATEN_OP_TABLE[Parent]["new_op"] = "added_later" + self.assertNotIn("new_op", Child._ATEN_OP_TABLE[Child]) + + def test_multiple_inheritance(self): + class A(TorchAOBaseTensor): + tensor_data_names = ["a"] + tensor_attribute_names = ["b"] + + class B(TorchAOBaseTensor): + tensor_data_names = ["a"] + tensor_attribute_names = ["b"] + + A._ATEN_OP_TABLE[A]["shared"] = "from_a" + B._ATEN_OP_TABLE[B]["shared"] = "from_b" + + class C(A, B): + tensor_data_names = ["a"] + tensor_attribute_names = ["b"] + + # C(A, B) should inherit from A then B, so B wins + self.assertEqual(C._ATEN_OP_TABLE[C]["shared"], "from_b") + if __name__ == "__main__": unittest.main() From 95fc1bc14f12ff8069771fc0e3f2936a6e9715da Mon Sep 17 00:00:00 2001 From: Krishn Parasar Date: Sat, 11 Oct 2025 20:52:11 +0530 Subject: [PATCH 2/3] Adding test with real op implementation the new test with real op fails for the same line self.assertIsNot(Parent._ATEN_OP_TABLE, Child._ATEN_OP_TABLE). Also when it is called using the child tensor, it works. For the inheritance test case, it fails at self.assertEqual(C._ATEN_OP_TABLE[C]["shared"], "from_b") AssertionError: 'from_a' != 'from_b' --- test/test_utils.py | 43 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 43 insertions(+) diff --git a/test/test_utils.py b/test/test_utils.py index 9b1c8c87d9..21661137ea 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -366,6 +366,49 @@ class Child(Parent): Parent._ATEN_OP_TABLE[Parent]["new_op"] = "added_later" self.assertNotIn("new_op", Child._ATEN_OP_TABLE[Child]) + def test_subclassing_with_real_op(self): + counter = {"calls": 0} + + class Parent(TorchAOBaseTensor): + tensor_data_names = ["qdata"] + tensor_attribute_names = ["attr"] + + def __new__(cls, qdata, attr): + r = torch.Tensor._make_wrapper_subclass(cls, qdata.shape) + r.qdata = qdata + r.attr = attr + return r + + def __init__(self, qdata, attr): + pass + + # Real op implementation + @Parent.implements([torch.ops.aten.cat.default]) + def _cat_op(func, types, args, kwargs): + counter["calls"] += 1 + return func(*args, **kwargs) + + class Child(Parent): + tensor_data_names = ["qdata"] + tensor_attribute_names = ["attr"] + + # Table checks + self.assertIn(torch.ops.aten.cat.default, Parent._ATEN_OP_TABLE[Parent]) + self.assertIn(torch.ops.aten.cat.default, Child._ATEN_OP_TABLE[Child]) + + # Ensure child table is distinct + self.assertIsNot(Parent._ATEN_OP_TABLE, Child._ATEN_OP_TABLE) + + # calling the op through the child tensor + t1 = torch.randn(2, 3) + t2 = torch.randn(2, 3) + child_tensor1 = Child(t1, "a") + child_tensor2 = Child(t2, "b") + + torch.ops.aten.cat.default([child_tensor1, child_tensor2], 0) + + self.assertEqual(counter["calls"], 1) + def test_multiple_inheritance(self): class A(TorchAOBaseTensor): tensor_data_names = ["a"] From 2402e3732d2592efeb720e2452b78007e7878e28 Mon Sep 17 00:00:00 2001 From: Krishn Parasar Date: Thu, 27 Nov 2025 20:14:44 +0530 Subject: [PATCH 3/3] Fixing subclassing Fix subclass initialization by ensuring each subclass gets its own per-class dispatch tables and correctly inherits parent op mappings. --- test/test_utils.py | 1 - torchao/utils.py | 16 ++++++++++------ 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/test/test_utils.py b/test/test_utils.py index 21661137ea..3d316accc4 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -386,7 +386,6 @@ def __init__(self, qdata, attr): @Parent.implements([torch.ops.aten.cat.default]) def _cat_op(func, types, args, kwargs): counter["calls"] += 1 - return func(*args, **kwargs) class Child(Parent): tensor_data_names = ["qdata"] diff --git a/torchao/utils.py b/torchao/utils.py index 875383a064..38dd0f2949 100644 --- a/torchao/utils.py +++ b/torchao/utils.py @@ -828,9 +828,9 @@ def __init__( @classmethod def __init_subclass__(cls, **kwargs): - if not hasattr(cls, "_ATEN_OP_TABLE"): + if "_ATEN_OP_TABLE" not in cls.__dict__: cls._ATEN_OP_TABLE = {} - if not hasattr(cls, "_TORCH_FN_TABLE"): + if "_TORCH_FN_TABLE" not in cls.__dict__: cls._TORCH_FN_TABLE = {} if cls not in cls._ATEN_OP_TABLE: cls._ATEN_OP_TABLE[cls] = {} @@ -846,10 +846,14 @@ def __init_subclass__(cls, **kwargs): # inherit the torch function and dispatch implementations from direct parent classes # e.g. for `class C(B, A)`, C.__bases__ == (B, A) for parent in cls.__bases__: - if hasattr(cls, "_ATEN_OP_TABLE") and parent in cls._ATEN_OP_TABLE: - cls._ATEN_OP_TABLE[cls].update(cls._ATEN_OP_TABLE[parent]) - if hasattr(cls, "_TORCH_FN_TABLE") and parent in cls._TORCH_FN_TABLE: - cls._TORCH_FN_TABLE[cls].update(cls._TORCH_FN_TABLE[parent]) + parent_aten_table = getattr(parent, "_ATEN_OP_TABLE", None) + if parent_aten_table and parent in parent_aten_table: + # shallow-copy parent's per-class op mapping into child's per-class mapping + cls._ATEN_OP_TABLE[cls].update(parent_aten_table[parent]) + + parent_torch_table = getattr(parent, "_TORCH_FN_TABLE", None) + if parent_torch_table and parent in parent_torch_table: + cls._TORCH_FN_TABLE[cls].update(parent_torch_table[parent]) implements = classmethod(_implements) implements_torch_function = classmethod(_implements_torch_function)