Skip to content

Commit 57816c9

Browse files
committed
Add boilerplate code
ghstack-source-id: ab8565eb263a52424f59f9cae71688803ec0b5de Pull Request resolved: #1635
1 parent abd41e5 commit 57816c9

File tree

2 files changed

+70
-4
lines changed

2 files changed

+70
-4
lines changed

Diff for: test/test_utils.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,11 @@ def __init__(self, data):
4141
self.data = data
4242

4343
l = torch.nn.Linear(10, 10)
44-
with self.assertRaisesRegex(NotImplementedError, "arg_types"):
44+
with self.assertRaisesRegex(NotImplementedError, "Subclasses must implement"):
4545
l.weight = torch.nn.Parameter(MyTensor(l.weight))
4646

47+
assert MyTensor(l.weight).get_layout() is None
48+
4749

4850
if __name__ == "__main__":
4951
unittest.main()

Diff for: torchao/utils.py

+67-3
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
import torch
1111
import torch.nn.utils.parametrize as parametrize
12+
from torch.utils._python_dispatch import return_and_correct_aliasing
1213

1314
__all__ = [
1415
"benchmark_model",
@@ -512,9 +513,15 @@ def _get_tensor_impl_constructor(
512513
return tensor_class._LAYOUT_CONSTRUCTOR_TABLE[layout_class]
513514

514515

516+
"""
517+
TensorAOBase subclass is a util tensor subclass that provides commonly used functions, and should be inherited to define a new tensor subclass
518+
"""
519+
520+
515521
class TorchAOBaseTensor(torch.Tensor):
516522
"""A util tensor subclass that provides commonly used functions
517-
new tensor subclass can inherit it to get all the utility functions
523+
new tensor subclass can inherit it to get all the utility functions, and
524+
should be inherited to define a new tensor subclass
518525
519526
class MyTensor(TorchAOBaseTensor):
520527
pass
@@ -539,12 +546,17 @@ def _(func, types, args, kwargs):
539546
class PlainAQTTensorImpl(...):
540547
...
541548
542-
`get_tensor_impl_constructor`:
549+
`get_tensor_impl_constructor`:
543550
get_tensor_impl_constructor = MyTensor.get_tensor_impl_constructor
544551
# in constructor of MyTensor:
545552
tensor_impl_ctr = get_tensor_impl_constructor(type(_layout))
546553
tensor_impl = tensor_impl_ctr(data, scale, zero_point, _layout)
547554
555+
`__tensor_flatten__` and `__tensor_unflatten__`: Used for tensor serialization/deserialization, must be defined in the subclass
556+
557+
`__repr__`: Used for tensor representation, must be defined in the subclass
558+
559+
`_apply_fn_to_data`: Used for applying a function to the data of the tensor, must be defined in the subclass
548560
"""
549561

550562
implements = classmethod(_implements)
@@ -573,6 +585,58 @@ def _get_to_kwargs(self, *args, **kwargs):
573585
}
574586
return kwargs
575587

588+
def __tensor_flatten__(self):
589+
raise NotImplementedError("Subclasses must implement __tensor_flatten__")
590+
591+
@classmethod
592+
def __tensor_unflatten__(
593+
cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
594+
):
595+
raise NotImplementedError("Subclasses must implement __tensor_unflatten__")
596+
597+
def __repr__(self):
598+
raise NotImplementedError("Subclasses must implement __repr__")
599+
600+
def _apply_fn_to_data(self, fn):
601+
raise NotImplementedError("Subclasses must implement _apply_fn_to_data")
602+
603+
def get_layout(self):
604+
if not hasattr(self, "_layout"):
605+
return None
606+
return self._layout
607+
608+
609+
implements = TorchAOBaseTensor.implements
610+
aten = torch.ops.aten
611+
612+
613+
@implements(aten.detach.default)
614+
def _(func, types, args, kwargs):
615+
return return_and_correct_aliasing(
616+
func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
617+
)
618+
619+
620+
@implements(aten.clone.default)
621+
def _(func, types, args, kwargs):
622+
return return_and_correct_aliasing(
623+
func, args, kwargs, args[0]._apply_fn_to_data(torch.clone)
624+
)
625+
626+
627+
@implements(aten.t.default)
628+
def _(func, types, args, kwargs):
629+
return return_and_correct_aliasing(
630+
func, args, kwargs, args[0]._apply_fn_to_data(torch.t)
631+
)
632+
633+
634+
@implements(aten.slice.Tensor)
635+
def _(func, types, args, kwargs):
636+
return return_and_correct_aliasing(
637+
func, args, kwargs, args[0]._apply_fn_to_data(torch.slice)
638+
)
639+
576640

577641
def fill_defaults(args, n, defaults_tail):
578642
"""
@@ -599,7 +663,7 @@ def fill_defaults(args, n, defaults_tail):
599663
return r
600664

601665

602-
## Deprecated, will be deleted in the future
666+
# Deprecated, will be deleted in the future
603667
def _torch_version_at_least(min_version):
604668
return is_fbcode() or version("torch") >= min_version
605669

0 commit comments

Comments
 (0)