Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 85 additions & 0 deletions test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,91 @@ def fake_linear(func, types, args, kwargs):
counter["calls"], 2, "Expected fake_linear to be called via aten.t.default"
)

def test_subclassing(self):
class Parent(TorchAOBaseTensor):
tensor_data_names = ["qdata"]
tensor_attribute_names = ["attr"]

Parent._ATEN_OP_TABLE[Parent]["op_parent"] = "parent_impl"
Parent._TORCH_FN_TABLE[Parent]["fn_parent"] = "parent_fn_impl"

class Child(Parent):
tensor_data_names = ["qdata"]
tensor_attribute_names = ["attr"]

# ensure child has copied parent ops
self.assertEqual(Child._ATEN_OP_TABLE[Child]["op_parent"], "parent_impl")
self.assertEqual(Child._TORCH_FN_TABLE[Child]["fn_parent"], "parent_fn_impl")

# ensure the top-level dicts are distinct (not inherited)
self.assertIsNot(Parent._ATEN_OP_TABLE, Child._ATEN_OP_TABLE)
self.assertIsNot(Parent._TORCH_FN_TABLE, Child._TORCH_FN_TABLE)

# change the parent's op after subclass creation — should not leak
Parent._ATEN_OP_TABLE[Parent]["new_op"] = "added_later"
self.assertNotIn("new_op", Child._ATEN_OP_TABLE[Child])

def test_subclassing_with_real_op(self):
counter = {"calls": 0}

class Parent(TorchAOBaseTensor):
tensor_data_names = ["qdata"]
tensor_attribute_names = ["attr"]

def __new__(cls, qdata, attr):
r = torch.Tensor._make_wrapper_subclass(cls, qdata.shape)
r.qdata = qdata
r.attr = attr
return r

def __init__(self, qdata, attr):
pass

# Real op implementation
@Parent.implements([torch.ops.aten.cat.default])
def _cat_op(func, types, args, kwargs):
counter["calls"] += 1

class Child(Parent):
tensor_data_names = ["qdata"]
tensor_attribute_names = ["attr"]

# Table checks
self.assertIn(torch.ops.aten.cat.default, Parent._ATEN_OP_TABLE[Parent])
self.assertIn(torch.ops.aten.cat.default, Child._ATEN_OP_TABLE[Child])

# Ensure child table is distinct
self.assertIsNot(Parent._ATEN_OP_TABLE, Child._ATEN_OP_TABLE)

# calling the op through the child tensor
t1 = torch.randn(2, 3)
t2 = torch.randn(2, 3)
child_tensor1 = Child(t1, "a")
child_tensor2 = Child(t2, "b")

torch.ops.aten.cat.default([child_tensor1, child_tensor2], 0)

self.assertEqual(counter["calls"], 1)

def test_multiple_inheritance(self):
class A(TorchAOBaseTensor):
tensor_data_names = ["a"]
tensor_attribute_names = ["b"]

class B(TorchAOBaseTensor):
tensor_data_names = ["a"]
tensor_attribute_names = ["b"]

A._ATEN_OP_TABLE[A]["shared"] = "from_a"
B._ATEN_OP_TABLE[B]["shared"] = "from_b"

class C(A, B):
tensor_data_names = ["a"]
tensor_attribute_names = ["b"]

# C(A, B) should inherit from A then B, so B wins
self.assertEqual(C._ATEN_OP_TABLE[C]["shared"], "from_b")


if __name__ == "__main__":
unittest.main()
16 changes: 10 additions & 6 deletions torchao/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -828,9 +828,9 @@ def __init__(

@classmethod
def __init_subclass__(cls, **kwargs):
if not hasattr(cls, "_ATEN_OP_TABLE"):
if "_ATEN_OP_TABLE" not in cls.__dict__:
cls._ATEN_OP_TABLE = {}
if not hasattr(cls, "_TORCH_FN_TABLE"):
if "_TORCH_FN_TABLE" not in cls.__dict__:
cls._TORCH_FN_TABLE = {}
if cls not in cls._ATEN_OP_TABLE:
cls._ATEN_OP_TABLE[cls] = {}
Expand All @@ -846,10 +846,14 @@ def __init_subclass__(cls, **kwargs):
# inherit the torch function and dispatch implementations from direct parent classes
# e.g. for `class C(B, A)`, C.__bases__ == (B, A)
for parent in cls.__bases__:
if hasattr(cls, "_ATEN_OP_TABLE") and parent in cls._ATEN_OP_TABLE:
cls._ATEN_OP_TABLE[cls].update(cls._ATEN_OP_TABLE[parent])
if hasattr(cls, "_TORCH_FN_TABLE") and parent in cls._TORCH_FN_TABLE:
cls._TORCH_FN_TABLE[cls].update(cls._TORCH_FN_TABLE[parent])
parent_aten_table = getattr(parent, "_ATEN_OP_TABLE", None)
if parent_aten_table and parent in parent_aten_table:
# shallow-copy parent's per-class op mapping into child's per-class mapping
cls._ATEN_OP_TABLE[cls].update(parent_aten_table[parent])

parent_torch_table = getattr(parent, "_TORCH_FN_TABLE", None)
if parent_torch_table and parent in parent_torch_table:
cls._TORCH_FN_TABLE[cls].update(parent_torch_table[parent])

implements = classmethod(_implements)
implements_torch_function = classmethod(_implements_torch_function)
Expand Down