diff --git a/test/test_utils.py b/test/test_utils.py index 0e77388f13..3d316accc4 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -342,6 +342,91 @@ def fake_linear(func, types, args, kwargs): counter["calls"], 2, "Expected fake_linear to be called via aten.t.default" ) + def test_subclassing(self): + class Parent(TorchAOBaseTensor): + tensor_data_names = ["qdata"] + tensor_attribute_names = ["attr"] + + Parent._ATEN_OP_TABLE[Parent]["op_parent"] = "parent_impl" + Parent._TORCH_FN_TABLE[Parent]["fn_parent"] = "parent_fn_impl" + + class Child(Parent): + tensor_data_names = ["qdata"] + tensor_attribute_names = ["attr"] + + # ensure child has copied parent ops + self.assertEqual(Child._ATEN_OP_TABLE[Child]["op_parent"], "parent_impl") + self.assertEqual(Child._TORCH_FN_TABLE[Child]["fn_parent"], "parent_fn_impl") + + # ensure the top-level dicts are distinct (not inherited) + self.assertIsNot(Parent._ATEN_OP_TABLE, Child._ATEN_OP_TABLE) + self.assertIsNot(Parent._TORCH_FN_TABLE, Child._TORCH_FN_TABLE) + + # change the parent's op after subclass creation — should not leak + Parent._ATEN_OP_TABLE[Parent]["new_op"] = "added_later" + self.assertNotIn("new_op", Child._ATEN_OP_TABLE[Child]) + + def test_subclassing_with_real_op(self): + counter = {"calls": 0} + + class Parent(TorchAOBaseTensor): + tensor_data_names = ["qdata"] + tensor_attribute_names = ["attr"] + + def __new__(cls, qdata, attr): + r = torch.Tensor._make_wrapper_subclass(cls, qdata.shape) + r.qdata = qdata + r.attr = attr + return r + + def __init__(self, qdata, attr): + pass + + # Real op implementation + @Parent.implements([torch.ops.aten.cat.default]) + def _cat_op(func, types, args, kwargs): + counter["calls"] += 1 + + class Child(Parent): + tensor_data_names = ["qdata"] + tensor_attribute_names = ["attr"] + + # Table checks + self.assertIn(torch.ops.aten.cat.default, Parent._ATEN_OP_TABLE[Parent]) + self.assertIn(torch.ops.aten.cat.default, Child._ATEN_OP_TABLE[Child]) + + # Ensure child table is distinct + self.assertIsNot(Parent._ATEN_OP_TABLE, Child._ATEN_OP_TABLE) + + # calling the op through the child tensor + t1 = torch.randn(2, 3) + t2 = torch.randn(2, 3) + child_tensor1 = Child(t1, "a") + child_tensor2 = Child(t2, "b") + + torch.ops.aten.cat.default([child_tensor1, child_tensor2], 0) + + self.assertEqual(counter["calls"], 1) + + def test_multiple_inheritance(self): + class A(TorchAOBaseTensor): + tensor_data_names = ["a"] + tensor_attribute_names = ["b"] + + class B(TorchAOBaseTensor): + tensor_data_names = ["a"] + tensor_attribute_names = ["b"] + + A._ATEN_OP_TABLE[A]["shared"] = "from_a" + B._ATEN_OP_TABLE[B]["shared"] = "from_b" + + class C(A, B): + tensor_data_names = ["a"] + tensor_attribute_names = ["b"] + + # C(A, B) should inherit from A then B, so B wins + self.assertEqual(C._ATEN_OP_TABLE[C]["shared"], "from_b") + if __name__ == "__main__": unittest.main() diff --git a/torchao/utils.py b/torchao/utils.py index 875383a064..38dd0f2949 100644 --- a/torchao/utils.py +++ b/torchao/utils.py @@ -828,9 +828,9 @@ def __init__( @classmethod def __init_subclass__(cls, **kwargs): - if not hasattr(cls, "_ATEN_OP_TABLE"): + if "_ATEN_OP_TABLE" not in cls.__dict__: cls._ATEN_OP_TABLE = {} - if not hasattr(cls, "_TORCH_FN_TABLE"): + if "_TORCH_FN_TABLE" not in cls.__dict__: cls._TORCH_FN_TABLE = {} if cls not in cls._ATEN_OP_TABLE: cls._ATEN_OP_TABLE[cls] = {} @@ -846,10 +846,14 @@ def __init_subclass__(cls, **kwargs): # inherit the torch function and dispatch implementations from direct parent classes # e.g. for `class C(B, A)`, C.__bases__ == (B, A) for parent in cls.__bases__: - if hasattr(cls, "_ATEN_OP_TABLE") and parent in cls._ATEN_OP_TABLE: - cls._ATEN_OP_TABLE[cls].update(cls._ATEN_OP_TABLE[parent]) - if hasattr(cls, "_TORCH_FN_TABLE") and parent in cls._TORCH_FN_TABLE: - cls._TORCH_FN_TABLE[cls].update(cls._TORCH_FN_TABLE[parent]) + parent_aten_table = getattr(parent, "_ATEN_OP_TABLE", None) + if parent_aten_table and parent in parent_aten_table: + # shallow-copy parent's per-class op mapping into child's per-class mapping + cls._ATEN_OP_TABLE[cls].update(parent_aten_table[parent]) + + parent_torch_table = getattr(parent, "_TORCH_FN_TABLE", None) + if parent_torch_table and parent in parent_torch_table: + cls._TORCH_FN_TABLE[cls].update(parent_torch_table[parent]) implements = classmethod(_implements) implements_torch_function = classmethod(_implements_torch_function)