Skip to content

Commit d42c725

Browse files
committed
Update
[ghstack-poisoned]
1 parent fb30b81 commit d42c725

File tree

2 files changed

+59
-23
lines changed

2 files changed

+59
-23
lines changed

test/test_utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,11 @@ def __init__(self, data):
4141
self.data = data
4242

4343
l = torch.nn.Linear(10, 10)
44-
with self.assertRaisesRegex(NotImplementedError, "arg_types"):
44+
with self.assertRaisesRegex(NotImplementedError, "Subclasses must implement"):
4545
l.weight = torch.nn.Parameter(MyTensor(l.weight))
4646

47+
assert MyTensor(l.weight).get_layout() is None
48+
4749

4850
if __name__ == "__main__":
4951
unittest.main()

torchao/utils.py

Lines changed: 56 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
import torch
1111
import torch.nn.utils.parametrize as parametrize
12+
from torch.utils._python_dispatch import return_and_correct_aliasing
1213

1314
__all__ = [
1415
"benchmark_model",
@@ -512,9 +513,36 @@ def _get_tensor_impl_constructor(
512513
return tensor_class._LAYOUT_CONSTRUCTOR_TABLE[layout_class]
513514

514515

516+
def _get_to_kwargs(self, *args, **kwargs):
517+
# `torch._C._nn._parse_to` can't handle `layout` argument
518+
for arg in args:
519+
if isinstance(arg, torch.layout):
520+
args.remove(arg)
521+
if "layout" in kwargs:
522+
kwargs.pop("layout")
523+
# ignoring `non_blocking` and `memory_format` args since these are not
524+
# very useful for most of the tensor subclasses
525+
# if in the future there are use cases that need these, we'd recommend
526+
# to override `_get_to_kwargs` and return these args
527+
device, dtype, _, _ = torch._C._nn._parse_to(*args, **kwargs)
528+
device = self.device if device is None else device
529+
dtype = self.dtype if dtype is None else dtype
530+
kwargs = {
531+
"device": device,
532+
"dtype": dtype,
533+
}
534+
return kwargs
535+
536+
537+
"""
538+
TensorAOBase subclass is a util tensor subclass that provides commonly used functions, and should be inherited to define a new tensor subclass
539+
"""
540+
541+
515542
class TorchAOBaseTensor(torch.Tensor):
516543
"""A util tensor subclass that provides commonly used functions
517-
new tensor subclass can inherit it to get all the utility functions
544+
new tensor subclass can inherit it to get all the utility functions, and
545+
should be inherited to define a new tensor subclass
518546
519547
class MyTensor(TorchAOBaseTensor):
520548
pass
@@ -539,39 +567,45 @@ def _(func, types, args, kwargs):
539567
class PlainAQTTensorImpl(...):
540568
...
541569
542-
`get_tensor_impl_constructor`:
570+
`get_tensor_impl_constructor`:
543571
get_tensor_impl_constructor = MyTensor.get_tensor_impl_constructor
544572
# in constructor of MyTensor:
545573
tensor_impl_ctr = get_tensor_impl_constructor(type(_layout))
546574
tensor_impl = tensor_impl_ctr(data, scale, zero_point, _layout)
547575
576+
`__tensor_flatten__` and `__tensor_unflatten__`: Used for tensor serialization/deserialization, must be defined in the subclass
577+
578+
`__repr__`: Used for tensor representation, must be defined in the subclass
579+
580+
`_apply_fn_to_data`: Used for applying a function to the data of the tensor, must be defined in the subclass
548581
"""
549582

550583
implements = classmethod(_implements)
551584
__torch_dispatch__ = classmethod(_dispatch__torch_dispatch__)
552585
__torch_function__ = classmethod(_dispatch__torch_function__)
553586
register_layout = classmethod(_register_layout)
554587
get_tensor_impl_constructor = classmethod(_get_tensor_impl_constructor)
588+
_get_to_kwargs = _get_to_kwargs
589+
590+
def __tensor_flatten__(self):
591+
raise NotImplementedError("Subclasses must implement __tensor_flatten__")
592+
593+
@classmethod
594+
def __tensor_unflatten__(
595+
cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
596+
):
597+
raise NotImplementedError("Subclasses must implement __tensor_unflatten__")
598+
599+
def __repr__(self):
600+
raise NotImplementedError("Subclasses must implement __repr__")
601+
602+
def _apply_fn_to_data(self, fn):
603+
raise NotImplementedError("Subclasses must implement _apply_fn_to_data")
555604

556-
def _get_to_kwargs(self, *args, **kwargs):
557-
# `torch._C._nn._parse_to` can't handle `layout` argument
558-
for arg in args:
559-
if isinstance(arg, torch.layout):
560-
args.remove(arg)
561-
if "layout" in kwargs:
562-
kwargs.pop("layout")
563-
# ignoring `non_blocking` and `memory_format` args since these are not
564-
# very useful for most of the tensor subclasses
565-
# if in the future there are use cases that need these, we'd recommend
566-
# to override `_get_to_kwargs` and return these args
567-
device, dtype, _, _ = torch._C._nn._parse_to(*args, **kwargs)
568-
device = self.device if device is None else device
569-
dtype = self.dtype if dtype is None else dtype
570-
kwargs = {
571-
"device": device,
572-
"dtype": dtype,
573-
}
574-
return kwargs
605+
def get_layout(self):
606+
if not hasattr(self, "_layout"):
607+
return None
608+
return self._layout
575609

576610

577611
def fill_defaults(args, n, defaults_tail):
@@ -599,7 +633,7 @@ def fill_defaults(args, n, defaults_tail):
599633
return r
600634

601635

602-
## Deprecated, will be deleted in the future
636+
# Deprecated, will be deleted in the future
603637
def _torch_version_at_least(min_version):
604638
return is_fbcode() or version("torch") >= min_version
605639

0 commit comments

Comments
 (0)