Skip to content
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 43 additions & 0 deletions test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,6 +392,49 @@ def fake_linear(func, types, args, kwargs):
counter["calls"], 2, "Expected fake_linear to be called via aten.t.default"
)

def test_subclassing(self):
class Parent(TorchAOBaseTensor):
Copy link
Contributor

@jerryzh168 jerryzh168 Oct 10, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the test should be doing this:

1. define Parent class
2. Parent class implements some method

# can do a real implementation or just some dummy one
@Parent.implements(aten.cat)
def _(..):
   ...
   
   
3. some child class inherits from parent: Child(Parent)

4. make sure the same op still works
by calling the op (e.g. aten.cat) and make sure it works (make sure it's called)

can do some additional table checks

tensor_data_names = ["qdata"]
tensor_attribute_names = ["attr"]

Parent._ATEN_OP_TABLE[Parent]["op_parent"] = "parent_impl"
Parent._TORCH_FN_TABLE[Parent]["fn_parent"] = "parent_fn_impl"

class Child(Parent):
tensor_data_names = ["qdata"]
tensor_attribute_names = ["attr"]

# ensure child has copied parent ops
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah I think idea is the same, but just changing it to real op implementation with implements and implements_torch_function would be better

self.assertEqual(Child._ATEN_OP_TABLE[Child]["op_parent"], "parent_impl")
self.assertEqual(Child._TORCH_FN_TABLE[Child]["fn_parent"], "parent_fn_impl")

# ensure the top-level dicts are distinct (not inherited)
self.assertIsNot(Parent._ATEN_OP_TABLE, Child._ATEN_OP_TABLE)
self.assertIsNot(Parent._TORCH_FN_TABLE, Child._TORCH_FN_TABLE)

# change the parent's op after subclass creation — should not leak
Parent._ATEN_OP_TABLE[Parent]["new_op"] = "added_later"
self.assertNotIn("new_op", Child._ATEN_OP_TABLE[Child])

def test_multiple_inheritance(self):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah testing multiple inheritance is helpful as well, thanks

class A(TorchAOBaseTensor):
tensor_data_names = ["a"]
tensor_attribute_names = ["b"]

class B(TorchAOBaseTensor):
tensor_data_names = ["a"]
tensor_attribute_names = ["b"]

A._ATEN_OP_TABLE[A]["shared"] = "from_a"
B._ATEN_OP_TABLE[B]["shared"] = "from_b"

class C(A, B):
tensor_data_names = ["a"]
tensor_attribute_names = ["b"]

# C(A, B) should inherit from A then B, so B wins
self.assertEqual(C._ATEN_OP_TABLE[C]["shared"], "from_b")


if __name__ == "__main__":
unittest.main()